Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
✨ Add method to add data to existing model
New functions allow adding additional corpus data to an existing model without having to resupply the entire corpus.
  • Loading branch information
andrlik committed Apr 29, 2024
commit 21b0bfd0f23911caac2141606f3e457af0f89e25
115 changes: 105 additions & 10 deletions src/django_markov/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ def _get_default_state_size() -> int:
STATE_SIZE = _get_default_state_size()


def _get_default_compile_setting() -> bool:
    """Return the project-wide default for storing compiled models.

    Reads ``settings.MARKOV_STORE_COMPILED_MODELS``. The value is honored only
    when it is an actual ``bool``; a missing setting, or a value of any other
    type, yields ``False``.

    Returns:
        bool: The configured default, or ``False`` when the setting is absent
            or is not a boolean.
    """
    configured = getattr(settings, "MARKOV_STORE_COMPILED_MODELS", None)
    if isinstance(configured, bool):
        return configured
    return False


class MarkovTextModel(models.Model):
"""Stores a compiled markov text model.

Expand Down Expand Up @@ -148,15 +157,99 @@ def _compiled_model(self) -> POSifiedText | None:
return text_model
return text_model.compile(inplace=True) # type: ignore

async def aadd_new_corpus_data_to_model(
    self,
    corpus_entries: list[str],
    *,
    char_limit: int | None = None,
    weights: list[float] | None = None,
) -> None:
    """Incorporate new corpus entries into the stored model and save it.

    Unlike `aupdate_model_from_corpus`, this method is additive: it builds a
    text model from the new entries only and merges it into the existing
    model with `markovify.combine`. When no model is stored yet, the call is
    delegated to `aupdate_model_from_corpus` (``weights`` is not used then).

    Args:
        corpus_entries (list[str]): A list of text sentences to add.
        char_limit (int | None): The maximum number of characters allowed in
            the joined new entries. Use `0` for no limit. If None, the value
            returned by `_get_corpus_char_limit()` is used.
        weights (list[float] | None): The weighting to use for combine
            operation, the first value representing the saved model, and the
            second representing the new entries.

    Raises:
        ValueError: If weights are supplied and do not have a length of two,
            or if the new corpus is over the character limit.
        MarkovEmptyError: If the new entries contain no non-whitespace text.
        MarkovCombineError: If the stored model is already compiled, or if
            markovify raises a ValueError while combining.
    """
    saved_model = self._as_text_model()
    if self.data is None or self.data == "" or saved_model is None:
        # There's no existing model, use update instead.
        return await self.aupdate_model_from_corpus(
            corpus_entries=corpus_entries, char_limit=char_limit
        )
    if char_limit is None:
        char_limit = _get_corpus_char_limit()
    if weights is not None and len(weights) != 2:  # noqa: PLR2004
        msg = "If provided, weights must have exactly two entries!"
        raise ValueError(msg)
    corpus = " ".join(corpus_entries)
    # An empty list joins to "", so one whitespace-insensitive check covers
    # both "no entries" and "entries containing only blanks/newlines/tabs".
    if not corpus.strip():
        msg = "There are no corpus entries to add!"
        raise MarkovEmptyError(msg)
    if saved_model.chain.compiled:
        msg = "Saved model is compiled, cannot combine!"
        raise MarkovCombineError(msg)
    # Enforce the limit on the new data only. This runs after the checks
    # above so that their exceptions keep taking precedence.
    if char_limit != 0 and char_limit < len(corpus):
        msg = f"Supplied corpus is over the maximum character limit: {char_limit}"
        raise ValueError(msg)
    # The new model must share the saved model's state size to be combinable.
    new_model = POSifiedText(corpus, state_size=saved_model.state_size)
    try:
        combined_model = markovify.combine(
            [saved_model, new_model], weights=weights
        )
    except ValueError as ve:  # no cov
        # If markovify raises any other unexpected error.
        msg = f"The following error occurred while combining: {ve}"
        raise MarkovCombineError(msg) from ve
    if isinstance(combined_model, POSifiedText):  # no cov
        self.data = combined_model.to_json()
        await self.asave()

def add_new_corpus_data_to_model(
    self,
    corpus_entries: list[str],
    *,
    char_limit: int | None = None,
    weights: list[float] | None = None,
) -> None:
    """Blocking counterpart of `aadd_new_corpus_data_to_model`.

    Wraps the async method with `async_to_sync` and forwards every argument
    to it unchanged.

    Args:
        corpus_entries (list[str]): A list of text sentences to add.
        char_limit (int | None): The character limit to use for the new corpus.
            Use `0` for no limit.
        weights (list[float] | None): The weighting to use for combine
            operation, the first value representing the saved model, and the
            second representing the new entries.

    Raises:
        MarkovCombineError: If the stored model is already compiled.
        MarkovEmptyError: If the new models are empty.
        ValueError: If weights are supplied, and they do not have a length of two.
    """
    run_blocking = async_to_sync(self.aadd_new_corpus_data_to_model)
    return run_blocking(
        corpus_entries=corpus_entries, char_limit=char_limit, weights=weights
    )

async def aupdate_model_from_corpus(
self,
corpus_entries: list[str],
*,
char_limit: int | None = None,
store_compiled: bool | None = None,
) -> None:
"""Takes the corpus and updates the model, saving it.
The corpus must not exceed the char_limit.
"""Takes a list of entries as the new full corpus and recreates the model,
saving it. The corpus must not exceed the char_limit.

Args:
corpus_entries (list[str]): The corpus as a list of text sentences.
Expand All @@ -165,17 +258,14 @@ async def aupdate_model_from_corpus(
store_compiled (bool | None): Whether to store the model in its compiled
state. If None, defaults to settings.MARKOV_STORE_COMPILED_MODELS or
False.

Raises:
ValueError: If the corpus is beyond the maximum character limit.
"""
if not char_limit:
char_limit = _get_corpus_char_limit()
if (
store_compiled is None
and hasattr(settings, "MARKOV_STORE_COMPILED_MODELS")
and isinstance(settings.MARKOV_STORE_COMPILED_MODELS, bool)
):
store_compiled = settings.MARKOV_STORE_COMPILED_MODELS
else:
store_compiled = False
if store_compiled is None:
store_compiled = _get_default_compile_setting()
corpus = " ".join(corpus_entries)
if char_limit != 0 and char_limit < len(corpus):
msg = f"Supplied corpus is over the maximum character limit: {char_limit}"
Expand Down Expand Up @@ -274,6 +364,11 @@ async def acombine_models(
Either a new MarkovTextModel instance
persisted to the database or a POSifiedText object to manipulate at a
low level, and the total number of models combined.

Raises:
ValueError: If any of the parameter combinations is invalid
MarkovCombineError: If models are incompatible for combining or a markovify
error is raised.
"""
# First we check to ensure that the models are combinable.
empty_models = []
Expand Down
91 changes: 91 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
MarkovEmptyError,
MarkovTextModel,
_get_corpus_char_limit,
_get_default_compile_setting,
_get_default_state_size,
)
from django_markov.text_models import POSifiedText
Expand All @@ -43,6 +44,17 @@ def test_get_char_limit_missing_settings(settings):
assert _get_corpus_char_limit() == 0 # Setting was not present


@pytest.mark.parametrize("override_value", [False, True])
def test_get_compile_default_setting(settings, override_value):
    """A boolean MARKOV_STORE_COMPILED_MODELS setting is returned unchanged."""
    settings.MARKOV_STORE_COMPILED_MODELS = override_value
    resolved_default = _get_default_compile_setting()
    assert resolved_default == override_value


def test_get_compile_default_missing_settings(settings):
    """Without the setting, the compile default resolves to a falsy value."""
    delattr(settings, "MARKOV_STORE_COMPILED_MODELS")
    resolved_default = _get_default_compile_setting()
    assert not resolved_default


@pytest.mark.parametrize(
"override_value,expected_result",
[
Expand Down Expand Up @@ -284,3 +296,82 @@ async def test_acombine_successful(
)
assert isinstance(result, expected_result_type)
assert total_combined == num_clean


def test_add_data_to_compiled_model_raises_exception(
    compiled_model, sample_corpus
) -> None:
    """Adding data to a compiled model fails and leaves the stored row as it was."""
    # Snapshot the persisted state so we can prove nothing was written.
    modified_before = compiled_model.modified
    data_before = compiled_model.data
    new_entries = [sample_corpus, "This is not going to work."]
    with pytest.raises(MarkovCombineError):
        compiled_model.add_new_corpus_data_to_model(new_entries)
    compiled_model.refresh_from_db()
    assert compiled_model.modified == modified_before
    assert compiled_model.data == data_before


@pytest.mark.parametrize(
    "corpus_entries,char_limit,weights,expected_exception",
    [
        ([], None, None, MarkovEmptyError),
        ([], None, [1.0, 1.0], MarkovEmptyError),
        (["I like springtime.", "Does this bring joy?"], None, [1.0], ValueError),
        (["I like springtime.", "Does this bring joy?"], 0, [], ValueError),
        (
            ["I like springtime.", "Does this bring joy?"],
            None,
            [1.0, 1.3, 1.0],
            ValueError,
        ),
    ],
)
def test_add_data_to_model_invocation_failures(
    text_model, sample_corpus, corpus_entries, char_limit, weights, expected_exception
):
    """Invalid invocations raise the expected error and leave the row untouched.

    Covers empty corpus input (`MarkovEmptyError`) and weight lists whose
    length is not exactly two (`ValueError`).
    """
    # Seed an uncompiled model so the additive code path is the one exercised.
    text_model.update_model_from_corpus([sample_corpus], store_compiled=False)
    text_model.refresh_from_db()
    old_data = text_model.data
    old_modify = text_model.modified
    with pytest.raises(expected_exception):
        # Forward char_limit as well; it was previously parametrized but never
        # passed, so the `0` case was not actually being tested.
        text_model.add_new_corpus_data_to_model(
            corpus_entries=corpus_entries, char_limit=char_limit, weights=weights
        )
    text_model.refresh_from_db()
    assert text_model.modified == old_modify
    assert text_model.data == old_data


@pytest.mark.parametrize(
    "corpus_entries,char_limit,weights",
    [
        (["I like springtime.", "Does this bring joy?"], None, None),
        (["I like springtime.", "Does this bring joy?"], 0, None),
        (["I like springtime.", "Does this bring joy?"], None, [1.0, 1.0]),
    ],
)
def test_add_data_to_model_success(
    text_model, sample_corpus, corpus_entries, char_limit, weights
):
    """Valid additive calls change the stored data and bump `modified`."""
    # Start from a persisted, uncompiled model built from the sample corpus.
    text_model.update_model_from_corpus([sample_corpus], store_compiled=False)
    text_model.refresh_from_db()
    data_before = text_model.data
    modified_before = text_model.modified
    call_kwargs = {
        "corpus_entries": corpus_entries,
        "char_limit": char_limit,
        "weights": weights,
    }
    text_model.add_new_corpus_data_to_model(**call_kwargs)
    text_model.refresh_from_db()
    assert text_model.data != data_before
    assert text_model.modified > modified_before


def test_add_data_to_empty_model_falls_back_to_update(text_model):
    """With no stored data, the additive call still produces a saved model."""
    assert not text_model.data
    modified_before = text_model.modified
    first_entries = ["I like springtime.", "Does this bring joy?"]
    text_model.add_new_corpus_data_to_model(corpus_entries=first_entries)
    text_model.refresh_from_db()
    assert text_model.data is not None
    assert text_model.modified > modified_before