Skip to content

Commit 5644ddc

Browse files
authored
Merge pull request #28 from CogStack/iob
feat: Support IOB and IOBES tagging schemes for HF NER models
2 parents 40cb37d + 0aafd88 commit 5644ddc

File tree

15 files changed

+1201
-107
lines changed

15 files changed

+1201
-107
lines changed

app/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class Settings(BaseSettings): # type: ignore
3434
TRAINING_METRICS_LOGGING_INTERVAL: int = 5 # the number of steps after which training metrics will be collected
3535
TRAINING_SAFE_MODEL_SERIALISATION: str = "false" # if "true", serialise the trained model using safe tensors
3636
TRAINING_CACHE_DIR: str = os.path.join(os.path.abspath(os.path.dirname(__file__)), "cms_cache") # the directory to cache the intermediate files created during training
37+
TRAINING_HF_TAGGING_SCHEME: str = "flat" # the tagging scheme used during Hugging Face NER model training, either "flat", "iob" or "iobes"
3738
HF_PIPELINE_AGGREGATION_STRATEGY: str = "simple" # the strategy used for aggregating the predictions of the Hugging Face NER model
3839
LOG_PER_CONCEPT_ACCURACIES: str = "false" # if "true", per-concept accuracies will be exposed to the metrics scraper. Switch this on with caution due to the potentially high number of concepts
3940
MEDCAT2_MAPPED_ONTOLOGIES: str = "" # the comma-separated names of ontologies for MedCAT2 to map to

app/domain.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ class Device(str, Enum):
7777
MPS = "mps"
7878

7979

80+
class TaggingScheme(str, Enum):
81+
FLAT = "flat"
82+
IOB = "iob"
83+
IOBES = "iobes"
84+
85+
8086
class HfTransformerBackbone(Enum):
8187
ALBERT = "albert"
8288
BIG_BIRD = "bert"
@@ -110,20 +116,24 @@ class LlmEngine(Enum):
110116
CMS = "CMS"
111117
VLLM = "vLLM"
112118

119+
113120
class LlmRole(Enum):
114121
SYSTEM = "system"
115122
USER = "user"
116123
ASSISTANT = "assistant"
117124
TOOL = "tool"
118125

126+
119127
class LlmTrainerType(Enum):
120128
GRPO = "grpo"
121129
PPO = "ppo"
122130

131+
123132
class LlmDatasetType(Enum):
124133
JSON = "json"
125134
CSV = "csv"
126135

136+
127137
class Annotation(BaseModel):
128138
doc_name: Optional[str] = Field(default=None, description="The name of the document to which the annotation belongs")
129139
start: int = Field(description="The start index of the annotation span")

app/envs/.env

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ TRAINING_SAFE_MODEL_SERIALISATION=false
7373
# The strategy used for aggregating the predictions of the Hugging Face NER model
7474
HF_PIPELINE_AGGREGATION_STRATEGY=simple
7575

76+
# The tagging scheme used during Hugging Face NER model training, either "flat", "iob" or "iobes"
77+
TRAINING_HF_TAGGING_SCHEME=flat
78+
7679
# The comma-separated names of ontologies for MedCAT2 to map to
7780
MEDCAT2_MAPPED_ONTOLOGIES=opcs4,icd10
7881

app/model_services/huggingface_llm_model.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from app.exception import ConfigurationException
1717
from app.model_services.base import AbstractModelService
1818
from app.trainers.huggingface_llm_trainer import HuggingFaceLlmSupervisedTrainer
19-
from app.domain import ModelCard, ModelType, Annotation
19+
from app.domain import ModelCard, ModelType, Annotation, Device
2020
from app.config import Settings
2121
from app.utils import (
2222
get_settings,
@@ -157,9 +157,19 @@ def load_model(
157157
bnb_4bit_compute_dtype=torch.bfloat16,
158158
bnb_4bit_use_double_quant=True,
159159
)
160-
model = AutoModelForCausalLM.from_pretrained(model_path, quantization_config=bnb_config)
160+
if get_settings().DEVICE == Device.DEFAULT.value:
161+
model = AutoModelForCausalLM.from_pretrained(
162+
model_path,
163+
quantization_config=bnb_config,
164+
device_map="auto",
165+
)
166+
else:
167+
model = AutoModelForCausalLM.from_pretrained(model_path, quantization_config=bnb_config)
161168
else:
162-
model = AutoModelForCausalLM.from_pretrained(model_path)
169+
if get_settings().DEVICE == Device.DEFAULT.value:
170+
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
171+
else:
172+
model = AutoModelForCausalLM.from_pretrained(model_path)
163173
ensure_tensor_contiguity(model)
164174
tokenizer = AutoTokenizer.from_pretrained(
165175
model_path,
@@ -242,8 +252,7 @@ def generate(
242252
self.model.eval()
243253

244254
inputs = self.tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
245-
if non_default_device_is_available(self._config.DEVICE):
246-
inputs.to(get_settings().DEVICE)
255+
inputs.to(self.model.device)
247256

248257
generation_kwargs = dict(
249258
inputs=inputs.input_ids,
@@ -291,8 +300,7 @@ async def generate_async(
291300
self.model.eval()
292301

293302
inputs = self.tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
294-
if non_default_device_is_available(self._config.DEVICE):
295-
inputs.to(get_settings().DEVICE)
303+
inputs.to(self.model.device)
296304

297305
streamer = TextIteratorStreamer(
298306
self.tokenizer,
@@ -363,8 +371,7 @@ def create_embeddings(
363371
truncation=True,
364372
)
365373

366-
if non_default_device_is_available(self._config.DEVICE):
367-
inputs.to(get_settings().DEVICE)
374+
inputs.to(self.model.device)
368375

369376
with torch.no_grad():
370377
outputs = self.model(**inputs, output_hidden_states=True)

app/model_services/huggingface_ner_model.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from app.exception import ConfigurationException
1717
from app.model_services.base import AbstractModelService
1818
from app.trainers.huggingface_ner_trainer import HuggingFaceNerUnsupervisedTrainer, HuggingFaceNerSupervisedTrainer
19-
from app.domain import ModelCard, ModelType, Annotation
19+
from app.domain import ModelCard, ModelType, Annotation, Device, TaggingScheme
2020
from app.config import Settings
2121
from app.utils import (
2222
get_settings,
@@ -27,6 +27,7 @@
2727
get_model_data_package_base_name,
2828
load_pydantic_object_from_dict,
2929
)
30+
from app.processors.tagging import TagProcessor
3031

3132
logger = logging.getLogger("cms")
3233

@@ -41,7 +42,7 @@ def __init__(
4142
enable_trainer: Optional[bool] = None,
4243
model_name: Optional[str] = None,
4344
base_model_file: Optional[str] = None,
44-
confidence_threshold: float = 0.5,
45+
confidence_threshold: float = 0.7,
4546
) -> None:
4647
"""
4748
Initialises the HuggingFace NER model service with specified configurations.
@@ -52,7 +53,7 @@ def __init__(
5253
enable_trainer (Optional[bool]): The flag to enable or disable trainers. Defaults to None.
5354
model_name (Optional[str]): The name of the model. Defaults to None.
5455
base_model_file (Optional[str]): The model package file name. Defaults to None.
55-
confidence_threshold (float): The threshold for the confidence score. Defaults to 0.5.
56+
confidence_threshold (float): The threshold for the confidence score. Defaults to 0.7.
5657
"""
5758

5859
super().__init__(config)
@@ -123,19 +124,20 @@ def from_model(cls, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase)
123124
HuggingFaceNerModel: A HuggingFace NER model service.
124125
"""
125126

126-
model_service = cls(get_settings(), enable_trainer=False)
127+
_config = get_settings()
128+
model_service = cls(_config, enable_trainer=False)
127129
model_service.model = model
128130
model_service.tokenizer = tokenizer
129131
_pipeline = partial(
130132
pipeline,
131133
task="ner",
132134
model=model_service.model,
133135
tokenizer=model_service.tokenizer,
134-
stride=10,
135-
aggregation_strategy=get_settings().HF_PIPELINE_AGGREGATION_STRATEGY,
136+
stride=32,
137+
aggregation_strategy=_config.HF_PIPELINE_AGGREGATION_STRATEGY,
136138
)
137-
if non_default_device_is_available(get_settings().DEVICE):
138-
model_service._ner_pipeline = _pipeline(device=get_hf_pipeline_device_id(get_settings().DEVICE))
139+
if non_default_device_is_available(_config.DEVICE):
140+
model_service._ner_pipeline = _pipeline(device=get_hf_pipeline_device_id(_config.DEVICE))
139141
else:
140142
model_service._ner_pipeline = _pipeline()
141143
return model_service
@@ -160,7 +162,10 @@ def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) ->
160162
model_path = os.path.join(os.path.dirname(model_file_path), get_model_data_package_base_name(model_file_path))
161163
if unpack_model_data_package(model_file_path, model_path):
162164
try:
163-
model = AutoModelForTokenClassification.from_pretrained(model_path)
165+
if get_settings().DEVICE == Device.DEFAULT.value:
166+
model = AutoModelForTokenClassification.from_pretrained(model_path, device_map="auto")
167+
else:
168+
model = AutoModelForTokenClassification.from_pretrained(model_path)
164169
ensure_tensor_contiguity(model)
165170
tokenizer = AutoTokenizer.from_pretrained(
166171
model_path,
@@ -197,7 +202,7 @@ def init_model(self, *args: Any, **kwargs: Any) -> None:
197202
task="ner",
198203
model=self._model,
199204
tokenizer=self._tokenizer,
200-
stride=10,
205+
stride=32,
201206
aggregation_strategy=self._config.HF_PIPELINE_AGGREGATION_STRATEGY,
202207
)
203208
if non_default_device_is_available(get_settings().DEVICE):
@@ -233,12 +238,29 @@ def annotate(self, text: str) -> List[Annotation]:
233238
List[Annotation]: A list of annotations containing the extracted named entities.
234239
"""
235240

236-
entities = self._ner_pipeline(text)
241+
if TaggingScheme(self._config.TRAINING_HF_TAGGING_SCHEME.lower()) == TaggingScheme.IOBES:
242+
entities = self._ner_pipeline(text, aggregation_strategy="none")
243+
else:
244+
entities = self._ner_pipeline(text)
237245
df = pd.DataFrame(entities)
238246

239247
if df.empty:
240248
columns = ["label_name", "label_id", "start", "end", "accuracy"]
241249
df = pd.DataFrame(columns=(columns + ["text"]) if self._config.INCLUDE_SPAN_TEXT == "true" else columns)
250+
elif TaggingScheme(self._config.TRAINING_HF_TAGGING_SCHEME.lower()) == TaggingScheme.IOBES:
251+
aggregated_entities = TagProcessor.aggregate_bioes_predictions(
252+
df,
253+
text,
254+
self._config.INCLUDE_SPAN_TEXT == "true",
255+
)
256+
df = pd.DataFrame(aggregated_entities)
257+
if df.empty:
258+
columns = ["label_name", "label_id", "start", "end", "accuracy"]
259+
df = pd.DataFrame(
260+
columns=(columns + ["text"]) if self._config.INCLUDE_SPAN_TEXT == "true" else columns
261+
)
262+
else:
263+
df = df[df["accuracy"] >= self._confidence_threshold]
242264
else:
243265
for idx, row in df.iterrows():
244266
df.loc[idx, "label_id"] = row["entity_group"]

0 commit comments

Comments
 (0)