Skip to content

Commit

Permalink
Vertically/horizontally stacked XRD plots (#223)
Browse files Browse the repository at this point in the history
* add ability to stack XRD plots vertically or horizontally

new xrd_pattern kwargs:
- stack: Literal["horizontal", "vertical"] | None = None
- subplot_kwargs: dict[str, Any] | None = None
- subtitle_kwargs: dict[str, Any] | None = None

* example XRD assets for horizontal/vertical stacking

* show stacked XRD plots in readme

* fix xrd_pattern doc string missing args types

* add unit tests for v/h stacked XRD plots and custom subplot titles
  • Loading branch information
janosh authored Oct 7, 2024
1 parent 82c087b commit c185585
Show file tree
Hide file tree
Showing 8 changed files with 252 additions and 71 deletions.
1 change: 1 addition & 0 deletions assets/xrd-pattern-horizontal-stack.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions assets/xrd-pattern-vertical-stack.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
34 changes: 32 additions & 2 deletions examples/make_assets/xrd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
lambda struct: f"{struct.formula} ({struct.get_space_group_info()[1]})"
)
structures = {
formula_spg_str(struct := Structure.from_file(file)): (struct)
formula_spg_str(struct := Structure.from_file(file)): struct
for file in glob(f"{TEST_FILES}/xrd/*.cif")
+ glob(f"{TEST_FILES}/structures/*.json.gz")
}
xrd_patterns = {
key: XRDCalculator().get_pattern(struct) for key, struct in structures.items()
}
key1, key2, *_ = xrd_patterns
key1, key2, key3, *_ = xrd_patterns


# %%
Expand All @@ -37,3 +38,32 @@
fig = pmv.xrd_pattern({key1: xrd_patterns[key1], key2: xrd_patterns[key2]})
fig.show()
pmv.io.save_and_compress_svg(fig, "xrd-pattern-multiple")


# %%
fig = pmv.xrd_pattern(
{key1: xrd_patterns[key1], key2: xrd_patterns[key2], key3: xrd_patterns[key3]},
stack="horizontal",
annotate_peaks=3,
show_angles=True,
)
fig.layout.title = dict(text="Horizontally Stacked XRD Patterns", x=0.5, y=0.97)
fig.layout.margin.t = 40
fig.show()
pmv.io.save_and_compress_svg(fig, "xrd-pattern-horizontal-stack")


# %% New example with vertical stacking and custom subplot titles
fig = pmv.xrd_pattern(
{f"{key1} {idx=}": structures[key1].copy().perturb(idx * 0.5) for idx in range(3)},
stack="vertical",
annotate_peaks=1,
show_angles=True,
subtitle_kwargs=dict(x=1, xanchor="right", font_size=14),
)
fig.layout.title = dict(
text="Vertically Stacked XRD Patterns with Custom Subplot Titles", x=0.5, y=0.97
)
fig.layout.margin.t = 40
fig.show()
pmv.io.save_and_compress_svg(fig, "xrd-pattern-vertical-stack")
16 changes: 7 additions & 9 deletions pymatviz/rdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,16 +226,14 @@ def element_pair_rdfs(
n_rows = (n_pairs + actual_cols - 1) // actual_cols

# Create the plotly figure with facets
fig = make_subplots(
**dict(
rows=n_rows,
cols=actual_cols,
subplot_titles=[f"{el1}-{el2}" for el1, el2 in element_pairs],
vertical_spacing=0.15 / n_rows,
horizontal_spacing=0.15 / actual_cols,
)
| subplot_kwargs
subplot_defaults = dict(
rows=n_rows,
cols=actual_cols,
subplot_titles=[f"{el1}-{el2}" for el1, el2 in element_pairs],
vertical_spacing=0.15 / n_rows,
horizontal_spacing=0.15 / actual_cols,
)
fig = make_subplots(**subplot_defaults | subplot_kwargs)

# Set default colors and line styles if not provided
line_styles = line_styles or "solid dot dash longdash dashdot longdashdot".split()
Expand Down
16 changes: 8 additions & 8 deletions pymatviz/structure_viz/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,14 +283,12 @@ def draw_site(
site_labels, site_idx, major_elem_symbol, majority_species
)

marker = (
dict(
size=site_radius * atom_size * (0.8 if is_image else 1),
color=color,
opacity=0.5 if is_image else 1,
)
| site_kwargs
marker = dict(
size=site_radius * atom_size * (0.8 if is_image else 1),
color=color,
opacity=0.5 if is_image else 1,
)
marker.update(site_kwargs)

scatter_kwargs = dict(
x=[coords[0]],
Expand Down Expand Up @@ -530,7 +528,9 @@ def get_first_matching_site_prop(
warn_if_none: bool = True,
filter_callback: Callable[[str, Any], bool] | None = None,
) -> str | None:
"""Find the first property key that exists in any structure or site properties.
"""Find the first property key that exists in any of the passed structures'
properties or site properties. Will look in site.properties first, then
structure.properties.
Args:
structures (Sequence[Structure]): pymatgen Structures to check.
Expand Down
158 changes: 118 additions & 40 deletions pymatviz/xrd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pymatgen.analysis.diffraction.xrd import DiffractionPattern, XRDCalculator
from pymatgen.core import Structure

Expand Down Expand Up @@ -38,7 +39,7 @@ def format_hkl(hkl: tuple[int, int, int], format_type: HklFormat) -> str:
raise ValueError(f"{format_type=} must be one of {ValidHklFormats}")


def xrd_pattern(
def xrd_pattern( # noqa: D417
patterns: PatternOrStruct
| dict[str, PatternOrStruct | tuple[PatternOrStruct, dict[str, Any]]],
*,
Expand All @@ -47,26 +48,38 @@ def xrd_pattern(
hkl_format: HklFormat = HklCompact,
show_angles: bool | None = None,
wavelength: float = 1.54184, # Cu K-alpha wavelength
stack: Literal["horizontal", "vertical"] | None = None,
subplot_kwargs: dict[str, Any] | None = None,
subtitle_kwargs: dict[str, Any] | None = None,
) -> go.Figure:
"""Create a plotly figure of XRD patterns from DiffractionPattern, Structure
objects, or a dictionary of them.
"""Create a plotly figure of XRD patterns from a pymatgen DiffractionPattern,
from a pymatgen Structure, or a dictionary of either of them.
Args:
patterns: Either a single DiffractionPattern or Structure object, or a
dictionary where keys are legend labels
patterns (PatternOrStruct | dict[str, PatternOrStruct | tuple[PatternOrStruct,
dict[str, Any]]]): Either a single DiffractionPattern or Structure object,
or a dictionary where keys are legend labels
and values are either DiffractionPattern/Structure objects or tuples of
(DiffractionPattern/Structure, kwargs) for customizing individual patterns.
peak_width: Width of the diffraction peaks in degrees. Default is 0.5.
annotate_peaks: Controls peak annotation. If int, annotates that many highest
peaks. If float, should be in (0, 1) which will annotate peaks higher than
that fraction of the highest peak. Default is 5.
hkl_format: Format for hkl indices. One of 'compact' ('100'), 'full'
('(1, 0, 0)'), or None for no hkl indices. Default is 'compact' for 3 or
fewer patterns, None for 4 or more patterns.
show_angles: Whether to show angles in peak annotations. If None, it will
default to True if plotting 1 or 2 patterns, False for 3 or more patterns.
wavelength: X-ray wavelength for the XRD calculation (in Angstroms). Default is
1.54184 (Cu K-alpha). Only used if patterns contains Structure objects.
peak_width (float): Width of the diffraction peaks in degrees. Default is 0.5.
annotate_peaks (float): Controls peak annotation. If int, annotates that many
highest peaks. If float, should be in (0, 1) which will annotate peaks
higher than that fraction of the highest peak. Default is 5.
hkl_format (HklFormat): Format for hkl indices. One of 'compact' (ex: '100'),
'full' (ex: '(1, 0, 0)'), or None for no hkl indices. Default is 'compact'.
show_angles (bool | None): Whether to show angles in peak annotations. If None,
it will default to True if plotting 1 or 2 patterns, False for 3 or more
patterns.
wavelength (float): X-ray wavelength for the XRD calculation (in Angstroms).
Default is 1.54184 (Cu K-alpha). Only used if patterns argument contains
Structures.
stack (Literal["horizontal", "vertical"] | None): If set to "horizontal" or
"vertical", creates separate subplots for each pattern. Default is None
(all patterns in one plot).
subplot_kwargs (dict[str, Any] | None): Passed to make_subplots. Can be used to
control spacing between subplots, e.g. {'vertical_spacing': 0.02}.
subtitle_kwargs (dict[str, Any] | None): Override default subplot title
settings. E.g. dict(font_size=14). Default is None.
Raises:
ValueError: If annotate_peaks is not a positive int or a float in (0, 1).
Expand All @@ -84,15 +97,6 @@ def xrd_pattern(
f"{annotate_peaks=} should be a positive int or a float in (0, 1)"
)

layout = dict(
xaxis=dict(title="2θ (degrees)", tickmode="linear", tick0=0, dtick=10),
yaxis=dict(title="Intensity (a.u.)", range=[0, 105]),
hovermode="x",
barmode="overlay",
)
fig = go.Figure(layout=layout)
max_intensity = max_two_theta = 0

# Convert single object to dict for uniform processing
if isinstance(patterns, DiffractionPattern | Structure):
patterns = {"XRD Pattern": patterns}
Expand All @@ -105,8 +109,28 @@ def xrd_pattern(
if show_angles is None:
show_angles = len(patterns) <= 2

n_patterns = len(patterns)
if stack:
rows, cols = (n_patterns, 1) if stack == "vertical" else (1, n_patterns)
subplot_defaults = dict(
rows=rows,
cols=cols,
shared_xaxes=True,
shared_yaxes=True,
horizontal_spacing=0.05 / cols,
vertical_spacing=0.05 / rows,
)
fig = make_subplots(**subplot_defaults | (subplot_kwargs or {}))
# increase peak width for horizontal stacking
if stack == "horizontal":
peak_width *= 3
else:
fig = go.Figure()

max_intensity = max_two_theta = 0
plotted_patterns: list[DiffractionPattern] = []
for label, pattern_data in patterns.items():

for trace_idx, (label, pattern_data) in enumerate(patterns.items()):
if isinstance(pattern_data, tuple):
pattern_or_struct, trace_kwargs = pattern_data
else:
Expand All @@ -123,28 +147,26 @@ def xrd_pattern(
f"{value=} should be a pymatgen Structure or DiffractionPattern"
)

plotted_patterns += [diffraction_pattern]
two_theta = diffraction_pattern.x
intensities = diffraction_pattern.y
hkls = diffraction_pattern.hkls
d_hkls = diffraction_pattern.d_hkls
plotted_patterns.append(diffraction_pattern)
two_theta, intensities = diffraction_pattern.x, diffraction_pattern.y
hkls, d_hkls = diffraction_pattern.hkls, diffraction_pattern.d_hkls

if intensities is None or len(intensities) == 0:
raise ValueError(
f"No intensities found in the diffraction pattern for {label}"
)

# Update max intensity and two_theta across all patterns
# get max intensity and two_theta across all patterns
max_intensity = max(max_intensity, *intensities)
max_two_theta = max(max_two_theta, *two_theta)

tooltips = [
f"<b>{label}</b><br>"
f"2θ: {x:.2f}°<br>Intensity: {y:.2f}<br>hkl: "
f"<b>{label}</b><br>2θ: {x:.2f}°<br>Intensity: {y:.2f}<br>hkl: "
f"{'<br>'.join(format_hkl(h['hkl'], HklFull) for h in hkl)}<br>d: {d:.3f} Å"
for x, y, hkl, d in zip(two_theta, intensities, hkls, d_hkls, strict=True)
]
fig.add_bar(

bar = go.Bar(
x=two_theta,
y=intensities,
width=peak_width,
Expand All @@ -154,6 +176,13 @@ def xrd_pattern(
**trace_kwargs,
)

if stack:
row = trace_idx + 1 if stack == "vertical" else 1
col = trace_idx + 1 if stack == "horizontal" else 1
fig.add_trace(bar, row=row, col=col)
else:
fig.add_trace(bar)

# Normalize intensities to 100 and add annotations
for trace_idx, trace in enumerate(fig.data):
trace.y = [y / max_intensity * 100 for y in trace.y]
Expand All @@ -170,8 +199,7 @@ def xrd_pattern(
peak_indices = []

for idx in peak_indices:
x_pos = trace.x[idx]
y_pos = trace.y[idx]
x_pos, y_pos = trace.x[idx], trace.y[idx]

if hkl_format:
hkl_formatted = "<br>".join(
Expand All @@ -195,7 +223,7 @@ def xrd_pattern(
elif y_pos > 90:
ay = abs(ay)

fig.add_annotation(
anno_kwargs = dict(
x=x_pos,
y=y_pos,
text=annotation_text,
Expand All @@ -208,7 +236,57 @@ def xrd_pattern(
yanchor="bottom" if ay < 0 else "top",
)

fig.layout.xaxis.range = [0, max_two_theta + 5]
fig.layout.legend.update(x=1, y=1, xanchor="right", yanchor="top")
if stack:
row = trace_idx + 1 if stack == "vertical" else 1
col = trace_idx + 1 if stack == "horizontal" else 1
fig.add_annotation(row=row, col=col, **anno_kwargs)
else:
fig.add_annotation(**anno_kwargs)

if stack:
# Add trace name annotation at the top of each subplot
row = trace_idx + 1 if stack == "vertical" else 1
col = trace_idx + 1 if stack == "horizontal" else 1
subtitle_defaults = dict(
x=0,
y=1,
showarrow=False,
font=dict(size=12),
xanchor="left",
yanchor="top",
)

fig.add_annotation(
text=trace.name,
xref=f"x{trace_idx + 1} domain".replace("x1 ", "x "),
yref=f"y{trace_idx + 1} domain".replace("y1 ", "y "),
row=row,
col=col,
**subtitle_defaults | (subtitle_kwargs or {}),
)

# Update layout
fig.update_layout(
xaxis=dict(
title="2θ (degrees)",
tickmode="linear",
tick0=0,
dtick=10,
range=[0, max_two_theta + 5],
),
yaxis=dict(title="Intensity (a.u.)", range=[0, 105]),
hovermode="x",
barmode="overlay",
showlegend=stack is None,
legend=dict(x=1, y=1, xanchor="right", yanchor="top"),
)

# move tick marks inside
fig.update_xaxes(ticks="inside")
fig.update_yaxes(ticks="inside")

if stack:
fig.update_xaxes(matches="x")
fig.update_yaxes(matches="y")

return fig
10 changes: 7 additions & 3 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,16 @@ See [`pymatviz/scatter.py`](pymatviz/scatter.py).

See [`pymatviz/xrd.py`](pymatviz/xrd.py).

| [`xrd_pattern(pattern)`](pymatviz/xrd.py) | [`xrd_pattern({key1: patt1, key2: patt2})`](pymatviz/xrd.py) |
| :---------------------------------------: | :----------------------------------------------------------: |
| ![xrd-pattern] | ![xrd-pattern-multiple] |
| [`xrd_pattern(pattern)`](pymatviz/xrd.py) | [`xrd_pattern({key1: patt1, key2: patt2})`](pymatviz/xrd.py) |
| :---------------------------------------------------------------: | :-----------------------------------------------------------------------------------: |
| ![xrd-pattern] | ![xrd-pattern-multiple] |
| [`xrd_pattern(struct_dict, stack="horizontal")`](pymatviz/xrd.py) | [`xrd_pattern(struct_dict, stack="vertical", title="Custom Title")`](pymatviz/xrd.py) |
| ![xrd-pattern-horizontal-stack] | ![xrd-pattern-vertical-stack] |

[xrd-pattern]: https://github.com/janosh/pymatviz/raw/main/assets/xrd-pattern.svg
[xrd-pattern-multiple]: https://github.com/janosh/pymatviz/raw/main/assets/xrd-pattern-multiple.svg
[xrd-pattern-horizontal-stack]: https://github.com/janosh/pymatviz/raw/main/assets/xrd-pattern-horizontal-stack.svg
[xrd-pattern-vertical-stack]: https://github.com/janosh/pymatviz/raw/main/assets/xrd-pattern-vertical-stack.svg

## Radial Distribution Functions

Expand Down
Loading

0 comments on commit c185585

Please sign in to comment.