|
1 | | -# chart_extractor.py |
2 | | -from __future__ import annotations |
3 | | - |
4 | 1 | import re |
5 | 2 | from dataclasses import dataclass |
6 | 3 | from pathlib import Path |
7 | 4 |
|
8 | 5 | import cv2 |
9 | 6 | import numpy as np |
10 | 7 | from huggingface_hub import hf_hub_download |
| 8 | +from rapidocr_onnxruntime import RapidOCR |
11 | 9 | from ultralytics import YOLO |
12 | 10 |
|
13 | | -# ---------- Config ---------- |
14 | | -# Class IDs (must match your training config) |
15 | 11 | CLS_SYMBOL_TITLE = 0 |
16 | 12 | CLS_LAST_PRICE_PILL = 1 |
17 | 13 |
|
18 | | -# Hugging Face model (adjust if you renamed the repo or path) |
19 | 14 | HF_MODEL_REPO = "StephanAkkerman/chart-info-detector" |
20 | | -HF_MODEL_FILE = "weights/best.pt" # path inside the model repo |
| 15 | +HF_MODEL_FILE = "weights/best.pt" |
21 | 16 |
|
22 | | -# OCR engine (lazy-loaded): RapidOCR preferred, Tesseract fallback |
23 | 17 | _OCR = None |
24 | 18 | _OCR_KIND = None # "rapid" | "tesseract" |
25 | 19 |
|
@@ -55,7 +49,6 @@ def _download_weights_if_needed( |
55 | 49 | return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model") |
56 | 50 |
|
57 | 51 |
|
58 | | -# --- NEW helpers: classify OCR text, smart swap --- |
59 | 52 | def _looks_like_price(text: str) -> bool: |
60 | 53 | if not text: |
61 | 54 | return False |
@@ -91,26 +84,10 @@ def _ensure_ocr(): |
91 | 84 | global _OCR, _OCR_KIND |
92 | 85 | if _OCR is not None: |
93 | 86 | return _OCR |
94 | | - try: |
95 | | - from rapidocr_onnxruntime import RapidOCR |
96 | | - |
97 | | - _OCR = RapidOCR() |
98 | | - _OCR_KIND = "rapid" |
99 | | - return _OCR |
100 | | - except Exception: |
101 | | - try: |
102 | | - import pytesseract # requires system tesseract (Windows installer / apt-get on Linux) |
103 | 87 |
|
104 | | - _OCR = pytesseract |
105 | | - _OCR_KIND = "tesseract" |
106 | | - return _OCR |
107 | | - except Exception as e2: |
108 | | - raise RuntimeError( |
109 | | - "No OCR engine available. Install one of:\n" |
110 | | - " pip install rapidocr-onnxruntime onnxruntime\n" |
111 | | - "or\n" |
112 | | - " sudo apt-get install tesseract-ocr && pip install pytesseract" |
113 | | - ) from e2 |
| 88 | + _OCR = RapidOCR() |
| 89 | + _OCR_KIND = "rapid" |
| 90 | + return _OCR |
114 | 91 |
|
115 | 92 |
|
116 | 93 | def _read_image(img: str | Path | np.ndarray) -> np.ndarray: |
@@ -175,7 +152,6 @@ def _ocr_text(im_bgr: np.ndarray) -> str: |
175 | 152 | ) |
176 | 153 |
|
177 | 154 |
|
178 | | -# --- REPLACE your title parser with this TradingView-oriented version --- |
179 | 155 | def _parse_title(text: str) -> tuple[str | None, str | None, str | None]: |
180 | 156 | """ |
181 | 157 | Parse (name, exchange, timeframe) from TradingView-style titles like: |
@@ -223,7 +199,6 @@ def _parse_title(text: str) -> tuple[str | None, str | None, str | None]: |
223 | 199 | return name, exchange, timeframe |
224 | 200 |
|
225 | 201 |
|
226 | | -# --- TIGHTER price parser (only from pill text) --- |
227 | 202 | def _parse_pill(text: str) -> tuple[float | None, str | None]: |
228 | 203 | """ |
229 | 204 | Parse (price, session) from pill text; avoid 'S&P 500' false matches. |
@@ -258,7 +233,6 @@ def _parse_pill(text: str) -> tuple[float | None, str | None]: |
258 | 233 | return price, ("regular" if sess is None else sess) |
259 | 234 |
|
260 | 235 |
|
261 | | -# ---------- Core pipeline ---------- |
262 | 236 | class ChartExtractor: |
263 | 237 | """ |
264 | 238 | Detects chart widgets (YOLO) and extracts info via OCR. |
@@ -419,7 +393,6 @@ def analyze( |
419 | 393 | return result |
420 | 394 |
|
421 | 395 |
|
422 | | -# ---------- Quick CLI test ---------- |
423 | 396 | if __name__ == "__main__": |
424 | 397 | import json |
425 | 398 | import sys |
|
0 commit comments