diff --git a/app/api/__init__.py b/app/api/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/api/config.py b/app/api/config.py
index 48ef59d..52599e4 100644
--- a/app/api/config.py
+++ b/app/api/config.py
@@ -84,7 +84,7 @@
 # LLM configurations
 MODEL_NAME = os.getenv("MODEL_NAME")
 LLM_DEFAULT_TEMPERATURE = float(os.getenv("LLM_DEFAULT_TEMPERATURE", 0.0))
-LLM_CHUNK_SIZE = int(os.getenv("LLM_CHUNK_SIZE", 512))
+LLM_CHUNK_SIZE = int(os.getenv("LLM_CHUNK_SIZE", 2000))
 LLM_CHUNK_OVERLAP = int(os.getenv("LLM_CHUNK_OVERLAP", 20))
 LLM_DISTANCE_THRESHOLD = float(os.getenv("LLM_DISTANCE_THRESHOLD", 0.5))
 LLM_MAX_OUTPUT_TOKENS = int(os.getenv("LLM_MAX_OUTPUT_TOKENS", 256))
diff --git a/app/api/llm.py b/app/api/llm.py
index eb4404e..a155eeb 100644
--- a/app/api/llm.py
+++ b/app/api/llm.py
@@ -2,11 +2,11 @@
 import openai
 import json
 
-from langchain.docstore.document import Document as LangChainDocument
-from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain_core.documents import Document as LangChainDocument
+from langchain_openai import OpenAIEmbeddings
 from fastapi import HTTPException
 from uuid import UUID, uuid4
-from langchain.text_splitter import (
+from langchain_text_splitters import (
     CharacterTextSplitter,
     MarkdownTextSplitter
 )
@@ -18,7 +18,7 @@
     sanitize_input,
     sanitize_output
 )
-from langchain import OpenAI
+from langchain_openai import OpenAI
 from typing import (
     List,
     Union,
diff --git a/app/api/tests/__init__.py b/app/api/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/api/tests/test_llm.py b/app/api/tests/test_llm.py
new file mode 100644
index 0000000..ea39f2b
--- /dev/null
+++ b/app/api/tests/test_llm.py
@@ -0,0 +1,83 @@
+
+import sys
+import os
+import pytest
+from sqlmodel import Session, SQLModel, create_engine
+# Add the app directory to the sys.path
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+from llm import chat_query, get_embeddings
+from models import Organization, Project, Document, Node
+from config import DATABASE_URL
+
+
+# Create a test database
+engine = create_engine(DATABASE_URL, echo=False)
+
+def create_db_and_tables():
+    SQLModel.metadata.create_all(engine)
+
+@pytest.fixture(name="session")
+def session_fixture():
+    create_db_and_tables()
+    with Session(engine) as session:
+        yield session
+    SQLModel.metadata.drop_all(engine)
+
+
+def test_kekzal_side_effects_bug(session: Session):
+    """
+    This test reproduces the bug where the chatbot fails to answer a question
+    about the side effects of Kekzal.
+    """
+    # 1. Create dummy organization and project
+    org = Organization(display_name="Test Org", namespace="test-org")
+    session.add(org)
+    session.commit()
+    session.refresh(org)
+
+    project = Project(display_name="Test Project", organization_id=org.id)
+    session.add(project)
+    session.commit()
+    session.refresh(project)
+
+    # 2. Load the Kekzal document (path anchored to this file, not the CWD)
+    doc_path = os.path.join(os.path.dirname(__file__), "..", "data", "training_data", "project-kekzal.md")
+    with open(doc_path, "r") as f:
+        doc_content = f.read()
+
+    doc = Document(
+        display_name="project-kekzal.md",
+        project_id=project.id,
+        organization_id=org.id,
+        data=doc_content,
+        hash="testhash",
+        version=1,
+    )
+    session.add(doc)
+    session.commit()
+    session.refresh(doc)
+
+    # 3. Create embeddings for the document
+    arr_documents, embeddings = get_embeddings(doc_content)
+    for i, (doc_chunk, embedding) in enumerate(zip(arr_documents, embeddings)):
+        node = Node(
+            document_id=doc.id,
+            text=doc_chunk,
+            embeddings=embedding,
+            node_order=i,
+        )
+        session.add(node)
+    session.commit()
+
+
+    # 4. Ask the question about side effects
+    response = chat_query(
+        query_str="What are the side effects of Kekzal?",
+        session=session,
+        project=project,
+        organization=org,
+    )
+
+    # 5. Assert that the response is correct
+    correct_response = "Some potential side effects of Kekzal may include"
+    assert correct_response in response.response