From 08c9605bc0c3117bab5b14b2f53e6a528b47ad32 Mon Sep 17 00:00:00 2001 From: arifal Date: Mon, 25 Aug 2025 15:19:04 +0700 Subject: [PATCH] add common response and prisma service and icd vector service and icd api --- package-lock.json | 43 ++ package.json | 1 + src/app.controller.spec.ts | 22 - src/app.controller.ts | 12 - src/app.module.ts | 12 +- src/app.service.ts | 8 - src/common/prisma/prisma.module.ts | 8 + .../prisma/prisma/prisma.service.spec.ts | 18 + src/common/prisma/prisma/prisma.service.ts | 16 + src/common/response/response.module.ts | 8 + .../response/response.interceptor.spec.ts | 7 + .../response/response/response.interceptor.ts | 26 ++ .../response/response.service.spec.ts | 18 + .../response/response/response.service.ts | 38 ++ src/icd-code-vector/icd-code-vector.module.ts | 10 + .../icd-code-vector.service.spec.ts | 18 + .../icd-code-vector.service.ts | 424 ++++++++++++++++++ src/icd-code/icd-code.module.ts | 13 + .../icd-code/icd-code.controller.spec.ts | 18 + src/icd-code/icd-code/icd-code.controller.ts | 419 +++++++++++++++++ .../icd-code/icd-code.service.spec.ts | 18 + src/icd-code/icd-code/icd-code.service.ts | 37 ++ src/icd/icd.controller.ts | 40 +- src/icd/icd.module.ts | 4 +- src/icd/icd.service.ts | 8 +- 25 files changed, 1171 insertions(+), 75 deletions(-) delete mode 100644 src/app.controller.spec.ts delete mode 100644 src/app.controller.ts delete mode 100644 src/app.service.ts create mode 100644 src/common/prisma/prisma.module.ts create mode 100644 src/common/prisma/prisma/prisma.service.spec.ts create mode 100644 src/common/prisma/prisma/prisma.service.ts create mode 100644 src/common/response/response.module.ts create mode 100644 src/common/response/response/response.interceptor.spec.ts create mode 100644 src/common/response/response/response.interceptor.ts create mode 100644 src/common/response/response/response.service.spec.ts create mode 100644 src/common/response/response/response.service.ts create mode 100644 src/icd-code-vector/icd-code-vector.module.ts create mode 100644 src/icd-code-vector/icd-code-vector/icd-code-vector.service.spec.ts create mode 100644 src/icd-code-vector/icd-code-vector/icd-code-vector.service.ts create mode 100644 src/icd-code/icd-code.module.ts create mode 100644 src/icd-code/icd-code/icd-code.controller.spec.ts create mode 100644 src/icd-code/icd-code/icd-code.controller.ts create mode 100644 src/icd-code/icd-code/icd-code.service.spec.ts create mode 100644 src/icd-code/icd-code/icd-code.service.ts diff --git a/package-lock.json b/package-lock.json index 38547d8..7735373 100644 --- a/package-lock.json +++ b/package-lock.json @@ -12,6 +12,7 @@ "@langchain/community": "^0.3.53", "@langchain/openai": "^0.6.9", "@nestjs/common": "^11.0.1", + "@nestjs/config": "^4.0.2", "@nestjs/core": "^11.0.1", "@nestjs/platform-express": "^11.0.1", "@nestjs/swagger": "^11.2.0", @@ -3168,6 +3169,33 @@ } } }, + "node_modules/@nestjs/config": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/@nestjs/config/-/config-4.0.2.tgz", + "integrity": "sha512-McMW6EXtpc8+CwTUwFdg6h7dYcBUpH5iUILCclAsa+MbCEvC9ZKu4dCHRlJqALuhjLw97pbQu62l4+wRwGeZqA==", + "license": "MIT", + "dependencies": { + "dotenv": "16.4.7", + "dotenv-expand": "12.0.1", + "lodash": "4.17.21" + }, + "peerDependencies": { + "@nestjs/common": "^10.0.0 || ^11.0.0", + "rxjs": "^7.1.0" + } + }, + "node_modules/@nestjs/config/node_modules/dotenv": { + "version": "16.4.7", + "resolved": "https://registry.npmjs.org/dotenv/-/dotenv-16.4.7.tgz", + "integrity": "sha512-47qPchRCykZC03FhkYAhrvwU4xDBFIj1QPqaarj6mdM/hgUzfPHcpkHJOn3mJAufFeeAxAzeGsr5X0M4k6fLZQ==", + "license": "BSD-2-Clause", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://dotenvx.com" + } + }, "node_modules/@nestjs/core": { "version": "11.1.6", "resolved": "https://registry.npmjs.org/@nestjs/core/-/core-11.1.6.tgz", @@ -6159,6 +6187,21 @@ "url": "https://dotenvx.com" } }, + "node_modules/dotenv-expand": { + "version": "12.0.1", + "resolved": "https://registry.npmjs.org/dotenv-expand/-/dotenv-expand-12.0.1.tgz", + "integrity": "sha512-LaKRbou8gt0RNID/9RoI+J2rvXsBRPMV7p+ElHlPhcSARbCPDYcYG2s1TIzAfWv4YSgyY5taidWzzs31lNV3yQ==", + "license": "BSD-2-Clause", + "dependencies": { + "dotenv": "^16.4.5" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://dotenvx.com" + } + }, "node_modules/dunder-proto": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz", diff --git a/package.json b/package.json index ab737f1..8132bd0 100644 --- a/package.json +++ b/package.json @@ -25,6 +25,7 @@ "@langchain/community": "^0.3.53", "@langchain/openai": "^0.6.9", "@nestjs/common": "^11.0.1", + "@nestjs/config": "^4.0.2", "@nestjs/core": "^11.0.1", "@nestjs/platform-express": "^11.0.1", "@nestjs/swagger": "^11.2.0", diff --git a/src/app.controller.spec.ts b/src/app.controller.spec.ts deleted file mode 100644 index d22f389..0000000 --- a/src/app.controller.spec.ts +++ /dev/null @@ -1,22 +0,0 @@ -import { Test, TestingModule } from '@nestjs/testing'; -import { AppController } from './app.controller'; -import { AppService } from './app.service'; - -describe('AppController', () => { - let appController: AppController; - - beforeEach(async () => { - const app: TestingModule = await Test.createTestingModule({ - controllers: [AppController], - providers: [AppService], - }).compile(); - - appController = app.get(AppController); - }); - - describe('root', () => { - it('should return "Hello World!"', () => { - expect(appController.getHello()).toBe('Hello World!'); - }); - }); -}); diff --git a/src/app.controller.ts b/src/app.controller.ts deleted file mode 100644 index cce879e..0000000 --- a/src/app.controller.ts +++ /dev/null @@ -1,12 +0,0 @@ -import { Controller, Get } from '@nestjs/common'; -import { AppService } from './app.service'; - -@Controller() -export class AppController { - constructor(private readonly appService: AppService) {} - - @Get() - getHello(): string { - return this.appService.getHello(); - } -} diff --git a/src/app.module.ts b/src/app.module.ts index 827aa67..c9f78b6 100644 --- a/src/app.module.ts +++ b/src/app.module.ts @@ -1,12 +1,14 @@ import { Module } from '@nestjs/common'; -import { AppController } from './app.controller'; -import { AppService } from './app.service'; import { IcdModule } from './icd/icd.module'; import { HealthModule } from './health/health.module'; +import { ResponseModule } from './common/response/response.module'; +import { PrismaModule } from './common/prisma/prisma.module'; +import { IcdCodeModule } from './icd-code/icd-code.module'; +import { IcdCodeVectorModule } from './icd-code-vector/icd-code-vector.module'; @Module({ - imports: [IcdModule, HealthModule], - controllers: [AppController], - providers: [AppService], + imports: [IcdModule, HealthModule, ResponseModule, PrismaModule, IcdCodeModule, IcdCodeVectorModule], + controllers: [], + providers: [], }) export class AppModule {} diff --git a/src/app.service.ts b/src/app.service.ts deleted file mode 100644 index 927d7cc..0000000 --- a/src/app.service.ts +++ /dev/null @@ -1,8 +0,0 @@ -import { Injectable } from '@nestjs/common'; - -@Injectable() -export class AppService { - getHello(): string { - return 'Hello World!'; - } -} diff --git a/src/common/prisma/prisma.module.ts b/src/common/prisma/prisma.module.ts new file mode 100644 index 0000000..4a5e644 --- /dev/null +++ b/src/common/prisma/prisma.module.ts @@ -0,0 +1,8 @@ +import { Module } from '@nestjs/common'; +import { PrismaService } from './prisma/prisma.service'; + +@Module({ + providers: [PrismaService], + exports: [PrismaService], +}) +export class PrismaModule {} diff --git a/src/common/prisma/prisma/prisma.service.spec.ts b/src/common/prisma/prisma/prisma.service.spec.ts new file mode 100644 index 0000000..a68cb9e --- /dev/null +++ b/src/common/prisma/prisma/prisma.service.spec.ts @@ -0,0 +1,18 @@ +import { Test, TestingModule } from '@nestjs/testing'; +import { PrismaService } from './prisma.service'; + +describe('PrismaService', () => { + let service: PrismaService; + + beforeEach(async () => { + const module: TestingModule = await Test.createTestingModule({ + providers: [PrismaService], + }).compile(); + + service = module.get(PrismaService); + }); + + it('should be defined', () => { + expect(service).toBeDefined(); + }); +}); diff --git a/src/common/prisma/prisma/prisma.service.ts b/src/common/prisma/prisma/prisma.service.ts new file mode 100644 index 0000000..ba00c9f --- /dev/null +++ b/src/common/prisma/prisma/prisma.service.ts @@ -0,0 +1,16 @@ +import { Injectable, OnModuleDestroy, OnModuleInit } from '@nestjs/common'; +import { PrismaClient } from '@prisma/client'; + +@Injectable() +export class PrismaService + extends PrismaClient + implements OnModuleInit, OnModuleDestroy +{ + async onModuleInit() { + await this.$connect(); + } + + async onModuleDestroy() { + await this.$disconnect(); + } +} diff --git a/src/common/response/response.module.ts b/src/common/response/response.module.ts new file mode 100644 index 0000000..c710986 --- /dev/null +++ b/src/common/response/response.module.ts @@ -0,0 +1,8 @@ +import { Module } from '@nestjs/common'; +import { ResponseService } from './response/response.service'; + +@Module({ + providers: [ResponseService], + exports: [ResponseService], +}) +export class ResponseModule {} diff --git a/src/common/response/response/response.interceptor.spec.ts b/src/common/response/response/response.interceptor.spec.ts new file mode 100644 index 0000000..80aef8f --- /dev/null +++ b/src/common/response/response/response.interceptor.spec.ts @@ -0,0 +1,7 @@ +import { ResponseInterceptor } from './response.interceptor'; + +describe('ResponseInterceptor', () => { + it('should be defined', () => { + expect(new ResponseInterceptor()).toBeDefined(); + }); +}); diff --git a/src/common/response/response/response.interceptor.ts b/src/common/response/response/response.interceptor.ts new file mode 100644 index 0000000..43171f0 --- /dev/null +++ b/src/common/response/response/response.interceptor.ts @@ -0,0 +1,26 @@ +import { + CallHandler, + ExecutionContext, + Injectable, + NestInterceptor, +} from '@nestjs/common'; +import { Observable, map } from 'rxjs'; + +@Injectable() +export class ResponseInterceptor implements NestInterceptor { + intercept(context: ExecutionContext, next: CallHandler): Observable { + return next.handle().pipe( + map((data) => { + if (data?.success !== undefined) { + return data; + } + + return { + status: true, + data, + message: 'Success', + }; + }), + ); + } +} diff --git a/src/common/response/response/response.service.spec.ts b/src/common/response/response/response.service.spec.ts new file mode 100644 index 0000000..ffaac60 --- /dev/null +++ b/src/common/response/response/response.service.spec.ts @@ -0,0 +1,18 @@ +import { Test, TestingModule } from '@nestjs/testing'; +import { ResponseService } from './response.service'; + +describe('ResponseService', () => { + let service: ResponseService; + + beforeEach(async () => { + const module: TestingModule = await Test.createTestingModule({ + providers: [ResponseService], + }).compile(); + + service = module.get(ResponseService); + }); + + it('should be defined', () => { + expect(service).toBeDefined(); + }); +}); diff --git a/src/common/response/response/response.service.ts b/src/common/response/response/response.service.ts new file mode 100644 index 0000000..0af9991 --- /dev/null +++ b/src/common/response/response/response.service.ts @@ -0,0 +1,38 @@ +import { Injectable } from '@nestjs/common'; + +@Injectable() +export class ResponseService { + success(data: T, message: string = 'Success') { + return { + status: true, + data, + message, + }; + } + + paginate(data: T[], total: number, page: number, pageSize: number) { + return { + status: true, + data, + total, + page, + pageSize, + }; + } + + ids(ids: (string | number)[], message: string = 'Success') { + return { + status: true, + data: ids, + message, + }; + } + + error(message: string = 'Error', statusCode: number = 400) { + return { + status: false, + message, + statusCode, + }; + } +} diff --git a/src/icd-code-vector/icd-code-vector.module.ts b/src/icd-code-vector/icd-code-vector.module.ts new file mode 100644 index 0000000..f8ef1b8 --- /dev/null +++ b/src/icd-code-vector/icd-code-vector.module.ts @@ -0,0 +1,10 @@ +import { Module } from '@nestjs/common'; +import { ConfigModule } from '@nestjs/config'; +import { IcdCodeVectorService } from './icd-code-vector/icd-code-vector.service'; + +@Module({ + imports: [ConfigModule], + providers: [IcdCodeVectorService], + exports: [IcdCodeVectorService], +}) +export class IcdCodeVectorModule {} diff --git a/src/icd-code-vector/icd-code-vector/icd-code-vector.service.spec.ts b/src/icd-code-vector/icd-code-vector/icd-code-vector.service.spec.ts new file mode 100644 index 0000000..c44cc22 --- /dev/null +++ b/src/icd-code-vector/icd-code-vector/icd-code-vector.service.spec.ts @@ -0,0 +1,18 @@ +import { Test, TestingModule } from '@nestjs/testing'; +import { IcdCodeVectorService } from './icd-code-vector.service'; + +describe('IcdCodeVectorService', () => { + let service: IcdCodeVectorService; + + beforeEach(async () => { + const module: TestingModule = await Test.createTestingModule({ + providers: [IcdCodeVectorService], + }).compile(); + + service = module.get(IcdCodeVectorService); + }); + + it('should be defined', () => { + expect(service).toBeDefined(); + }); +}); diff --git a/src/icd-code-vector/icd-code-vector/icd-code-vector.service.ts b/src/icd-code-vector/icd-code-vector/icd-code-vector.service.ts new file mode 100644 index 0000000..1d5236e --- /dev/null +++ b/src/icd-code-vector/icd-code-vector/icd-code-vector.service.ts @@ -0,0 +1,424 @@ +import { Injectable, Logger } from '@nestjs/common'; +import { ConfigService } from '@nestjs/config'; +import { OpenAIEmbeddings } from '@langchain/openai'; +import { Pool } from 'pg'; + +export interface VectorSearchResult { + id: string; + code: string; + display: string; + version: string; + category: string; + similarity: number; +} + +export interface EmbeddingGenerationResult { + processed: number; + errors: number; + totalSample: number; +} + +@Injectable() +export class IcdCodeVectorService { + private readonly logger = new Logger(IcdCodeVectorService.name); + private readonly pool: Pool; + private embeddings: OpenAIEmbeddings | null = null; + + constructor(private readonly configService: ConfigService) { + // Initialize PostgreSQL connection pool + this.pool = new Pool({ + connectionString: this.configService.get('DATABASE_URL'), + max: 20, + idleTimeoutMillis: 30000, + connectionTimeoutMillis: 2000, + }); + + this.initializeEmbeddings(); + } + + /** + * Initialize OpenAI embeddings + */ + private async initializeEmbeddings() { + try { + const apiKey = this.configService.get('OPENAI_API_KEY'); + if (!apiKey) { + this.logger.error( + 'OPENAI_API_KEY not found. Vector operations require OpenAI API key.', + ); + throw new Error('OPENAI_API_KEY is required for vector operations'); + } + + const apiModel = this.configService.get('OPENAI_API_MODEL'); + const modelName = apiModel || 'text-embedding-ada-002'; + + this.embeddings = new OpenAIEmbeddings({ + openAIApiKey: apiKey, + modelName: modelName, + maxConcurrency: 5, + }); + + this.logger.log( + `OpenAI embeddings initialized successfully with model: ${modelName}`, + ); + } catch (error) { + this.logger.error('Failed to initialize OpenAI embeddings:', error); + throw new Error( + `Failed to initialize OpenAI embeddings: ${error.message}`, + ); + } + } + + /** + * Generate embedding untuk text menggunakan OpenAI + */ + async generateEmbedding(text: string): Promise { + try { + this.logger.log( + `Generating embedding for text: ${text.substring(0, 100)}...`, + ); + + if (!this.embeddings) { + throw new Error( + 'OpenAI embeddings not initialized. Please check your API configuration.', + ); + } + + // Use OpenAI embeddings + const embedding = await this.embeddings.embedQuery(text); + this.logger.log( + `Generated OpenAI embedding with ${embedding.length} dimensions`, + ); + return embedding; + } catch (error) { + this.logger.error('Error generating embedding:', error); + throw new Error(`Failed to generate embedding: ${error.message}`); + } + } + + /** + * Generate dan simpan embeddings untuk ICD codes (default: 100) + */ + async generateAndStoreEmbeddings( + limit: number = 100, + ): Promise { + try { + this.logger.log( + `Starting batch embedding generation and storage for ${limit} ICD codes...`, + ); + + // Get ICD codes without embeddings using raw SQL + const codesWithoutEmbedding = await this.pool.query( + 'SELECT id, code, display, version, category FROM icd_codes WHERE embedding IS NULL LIMIT $1', + [limit], + ); + + if (codesWithoutEmbedding.rows.length === 0) { + this.logger.log('All ICD codes already have embeddings'); + return { processed: 0, errors: 0, totalSample: 0 }; + } + + this.logger.log( + `Found ${codesWithoutEmbedding.rows.length} codes without embeddings (limited to ${limit})`, + ); + + let processed = 0; + let errors = 0; + + // Process each code + for (let i = 0; i < codesWithoutEmbedding.rows.length; i++) { + const code = codesWithoutEmbedding.rows[i]; + try { + // Create text representation for embedding + const text = `${code.code} - ${code.display}`; + + // Generate embedding + const embedding = await this.generateEmbedding(text); + + // Convert embedding array to proper vector format for pgvector + const vectorString = `[${embedding.join(',')}]`; + + // Update database with embedding, metadata, and content using raw SQL + await this.pool.query( + `UPDATE icd_codes + SET embedding = $1::vector, + metadata = $2::jsonb, + content = $3, + "updatedAt" = NOW() + WHERE id = $4`, + [ + vectorString, + JSON.stringify({ + id: code.id, + code: code.code, + display: code.display, + version: code.version, + category: code.category, + }), + text, + code.id, + ], + ); + + processed++; + + if (processed % 10 === 0) { + this.logger.log( + `Processed ${processed}/${codesWithoutEmbedding.rows.length} embeddings`, + ); + } + } catch (error) { + this.logger.error(`Error processing code ${code.code}:`, error); + errors++; + } + } + + this.logger.log( + `Embedding generation and storage completed. Processed: ${processed}, Errors: ${errors}, Total: ${codesWithoutEmbedding.rows.length}`, + ); + return { + processed, + errors, + totalSample: codesWithoutEmbedding.rows.length, + }; + } catch (error) { + this.logger.error('Error in generateAndStoreEmbeddings:', error); + throw error; + } + } + + /** + * Generate dan simpan embeddings untuk ICD codes dengan kategori tertentu + */ + async generateAndStoreEmbeddingsByCategory( + category: string, + limit: number = 100, + ): Promise { + try { + this.logger.log( + `Starting batch embedding generation for ${limit} ICD codes in category: ${category}`, + ); + + // Get ICD codes by category without embeddings using raw SQL + const codesWithoutEmbedding = await this.pool.query( + 'SELECT id, code, display, version, category FROM icd_codes WHERE embedding IS NULL AND category = $1 LIMIT $2', + [category, limit], + ); + + if (codesWithoutEmbedding.rows.length === 0) { + this.logger.log( + `No ICD codes found in category '${category}' without embeddings`, + ); + return { processed: 0, errors: 0, totalSample: 0, category }; + } + + this.logger.log( + `Found ${codesWithoutEmbedding.rows.length} codes in category '${category}' without embeddings (limited to ${limit})`, + ); + + let processed = 0; + let errors = 0; + + // Process each code + for (let i = 0; i < codesWithoutEmbedding.rows.length; i++) { + const code = codesWithoutEmbedding.rows[i]; + try { + // Create text representation for embedding + const text = `${code.code} - ${code.display}`; + + // Generate embedding + const embedding = await this.generateEmbedding(text); + + // Convert embedding array to proper vector format for pgvector + const vectorString = `[${embedding.join(',')}]`; + + // Update database with embedding, metadata, and content using raw SQL + await this.pool.query( + `UPDATE icd_codes + SET embedding = $1::vector, + metadata = $2::jsonb, + content = $3, + "updatedAt" = NOW() + WHERE id = $4`, + [ + vectorString, + JSON.stringify({ + id: code.id, + code: code.code, + display: code.display, + version: code.version, + category: code.category, + }), + text, + code.id, + ], + ); + + processed++; + + if (processed % 10 === 0) { + this.logger.log( + `Processed ${processed}/${codesWithoutEmbedding.rows.length} embeddings in category '${category}'`, + ); + } + } catch (error) { + this.logger.error(`Error processing code ${code.code}:`, error); + errors++; + } + } + + this.logger.log( + `Embedding generation completed for category '${category}'. Processed: ${processed}, Errors: ${errors}, Total: ${codesWithoutEmbedding.rows.length}`, + ); + return { + processed, + errors, + totalSample: codesWithoutEmbedding.rows.length, + category, + }; + } catch (error) { + this.logger.error( + `Error in generateAndStoreEmbeddingsByCategory for category '${category}':`, + error, + ); + throw error; + } + } + + /** + * Vector similarity search menggunakan pgvector dengan threshold dari config + */ + async search( + query: string, + category?: string, + limit: number = 10, + ): Promise { + try { + // Get threshold from config service + const threshold = this.configService.get('THRESHOLD', 0.85); + + this.logger.log( + `Performing vector search for: "${query}" with threshold: ${threshold}${category ? `, category: ${category}` : ''}`, + ); + + if (!this.embeddings) { + throw new Error('OpenAI embeddings not initialized'); + } + + // Generate embedding for query + const queryEmbedding = await this.generateEmbedding(query); + + // Convert embedding array to proper vector format for pgvector + const vectorString = `[${queryEmbedding.join(',')}]`; + + // Build SQL query for vector similarity search + let sql = ` + SELECT + id, code, display, version, category, + (1 - (embedding <=> $1::vector)) as similarity + FROM icd_codes + WHERE embedding IS NOT NULL + AND (1 - (embedding <=> $1::vector)) >= $2 + `; + + const params: any[] = [vectorString, threshold]; + let paramIndex = 3; + + if (category) { + sql += ` AND category = $${paramIndex}`; + params.push(category); + paramIndex++; + } + + // Order by similarity descending and limit results + sql += ` ORDER BY similarity DESC LIMIT $${paramIndex}`; + params.push(limit); + + // Execute raw SQL query + const result = await this.pool.query(sql, params); + + // Transform results + const searchResults: VectorSearchResult[] = result.rows.map( + (row: any) => ({ + id: row.id, + code: row.code, + display: row.display, + version: row.version, + category: row.category, + similarity: parseFloat(row.similarity), + }), + ); + + this.logger.log( + `Vector search returned ${searchResults.length} results for query: "${query}" with threshold: ${threshold}`, + ); + return searchResults; + } catch (error) { + this.logger.error('Error in vector search:', error); + throw error; + } + } + + /** + * Search dengan kategori spesifik (ICD-9 atau ICD-10) + */ + async searchByCategory( + query: string, + category: 'ICD9' | 'ICD10', + limit: number = 10, + ): Promise { + return this.search(query, category, limit); + } + + /** + * Get embedding statistics + */ + async getEmbeddingStats(): Promise<{ + total: number; + withEmbeddings: number; + withoutEmbeddings: number; + percentage: number; + threshold: number; + }> { + try { + // Use raw SQL to get embedding statistics + const [totalResult, withEmbeddingsResult] = await Promise.all([ + this.pool.query('SELECT COUNT(*) as count FROM icd_codes'), + this.pool.query( + 'SELECT COUNT(*) as count FROM icd_codes WHERE embedding IS NOT NULL', + ), + ]); + + const total = parseInt(totalResult.rows[0].count); + const withEmbeddings = parseInt(withEmbeddingsResult.rows[0].count); + const withoutEmbeddings = total - withEmbeddings; + const percentage = total > 0 ? (withEmbeddings / total) * 100 : 0; + const threshold = this.configService.get('THRESHOLD', 0.85); + + return { + total, + withEmbeddings, + withoutEmbeddings, + percentage: Math.round(percentage * 100) / 100, + threshold, + }; + } catch (error) { + this.logger.error('Error getting embedding stats:', error); + throw error; + } + } + + /** + * Get current threshold configuration + */ + getThreshold(): number { + return this.configService.get('THRESHOLD', 0.85); + } + + /** + * Cleanup resources + */ + async onModuleDestroy() { + await this.pool.end(); + } +} diff --git a/src/icd-code/icd-code.module.ts b/src/icd-code/icd-code.module.ts new file mode 100644 index 0000000..1c668ff --- /dev/null +++ b/src/icd-code/icd-code.module.ts @@ -0,0 +1,13 @@ +import { Module } from '@nestjs/common'; +import { IcdCodeService } from './icd-code/icd-code.service'; +import { IcdCodeController } from './icd-code/icd-code.controller'; +import { PrismaModule } from 'src/common/prisma/prisma.module'; +import { ResponseModule } from 'src/common/response/response.module'; +import { IcdCodeVectorModule } from '../icd-code-vector/icd-code-vector.module'; + +@Module({ + imports: [PrismaModule, ResponseModule, IcdCodeVectorModule], + providers: [IcdCodeService], + controllers: [IcdCodeController], +}) +export class IcdCodeModule {} diff --git a/src/icd-code/icd-code/icd-code.controller.spec.ts b/src/icd-code/icd-code/icd-code.controller.spec.ts new file mode 100644 index 0000000..419ae1c --- /dev/null +++ b/src/icd-code/icd-code/icd-code.controller.spec.ts @@ -0,0 +1,18 @@ +import { Test, TestingModule } from '@nestjs/testing'; +import { IcdCodeController } from './icd-code.controller'; + +describe('IcdCodeController', () => { + let controller: IcdCodeController; + + beforeEach(async () => { + const module: TestingModule = await Test.createTestingModule({ + controllers: [IcdCodeController], + }).compile(); + + controller = module.get(IcdCodeController); + }); + + it('should be defined', () => { + expect(controller).toBeDefined(); + }); +}); diff --git a/src/icd-code/icd-code/icd-code.controller.ts b/src/icd-code/icd-code/icd-code.controller.ts new file mode 100644 index 0000000..daba2c3 --- /dev/null +++ b/src/icd-code/icd-code/icd-code.controller.ts @@ -0,0 +1,419 @@ +import { + Controller, + Get, + Post, + Query, + Body, + HttpStatus, + HttpCode, +} from '@nestjs/common'; +import { IcdCodeService } from './icd-code.service'; +import { IcdCodeVectorService } from '../../icd-code-vector/icd-code-vector/icd-code-vector.service'; +import { + ApiTags, + ApiOperation, + ApiQuery, + ApiBody, + ApiResponse, + ApiBadRequestResponse, + ApiInternalServerErrorResponse, +} from '@nestjs/swagger'; +import { ResponseService } from 'src/common/response/response/response.service'; + +export class GenerateEmbeddingsRequestDto { + limit?: number; +} + +export class GenerateEmbeddingsByCategoryRequestDto { + category: 'ICD9' | 'ICD10'; + limit?: number; +} + +export class VectorSearchRequestDto { + query: string; + category?: 'ICD9' | 'ICD10'; + limit?: number; +} + +export class GenerateEmbeddingsResponseDto { + message: string; + processed: number; + errors: number; + totalSample: number; +} + +export class VectorSearchResponseDto { + results: any[]; + total: number; + query: string; + category?: string; + threshold: number; +} + +@ApiTags('ICD Code') +@Controller('icd-code') +export class IcdCodeController { + constructor( + private readonly icdCodeService: IcdCodeService, + private readonly responseService: ResponseService, + private readonly icdCodeVectorService: IcdCodeVectorService, + ) {} + + @Get('data') + @ApiOperation({ + summary: 'Search ICD codes with filters and pagination', + description: + 'Search for ICD codes using various filters like category, search term, with pagination support. Returns a paginated list of matching ICD codes.', + }) + @ApiQuery({ + name: 'search', + required: false, + description: 'Search term for ICD code or description', + example: 'diabetes', + }) + @ApiQuery({ + name: 'page', + required: false, + description: 'Page number for pagination', + example: 1, + }) + @ApiQuery({ + name: 'limit', + required: false, + description: 'Number of items per page', + example: 10, + }) + async findIcdCodes( + @Query('search') search: string, + @Query('page') page: string, + @Query('limit') limit: string, + ) { + try { + const result = await this.icdCodeService.findIcdCodes( + search, + Number(page), + Number(limit), + ); + return this.responseService.paginate( + result.data, + result.page, + result.limit, + result.total, + ); + } catch (error) { + return this.responseService.error('Internal server error during search'); + } + } + + @Post('generate-embeddings') + @HttpCode(HttpStatus.OK) + @ApiOperation({ + summary: 'Generate and store embeddings for ICD codes', + description: + 'Batch generate embeddings for ICD codes and store them in the database with pgvector. This process may take some time depending on the number of codes.', + }) + @ApiBody({ type: GenerateEmbeddingsRequestDto }) + @ApiResponse({ + status: HttpStatus.OK, + description: 'Embedding generation and storage results summary', + type: GenerateEmbeddingsResponseDto, + }) + @ApiBadRequestResponse({ description: 'Invalid request parameters' }) + @ApiInternalServerErrorResponse({ + description: 'Internal server error during embedding generation', + }) + async generateAndStoreEmbeddings( + @Body() body: GenerateEmbeddingsRequestDto, + ): Promise { + try { + const result = await this.icdCodeVectorService.generateAndStoreEmbeddings( + body.limit, + ); + return { + message: `Processed ${result.processed} embeddings with ${result.errors} errors`, + ...result, + }; + } catch (error) { + throw new Error(`Failed to generate embeddings: ${error.message}`); + } + } + + @Post('generate-embeddings-by-category') + @HttpCode(HttpStatus.OK) + @ApiOperation({ + summary: 'Generate and store embeddings for ICD codes by category', + description: + 'Batch generate embeddings for ICD codes in a specific category (ICD9 or ICD10) and store them in the database.', + }) + @ApiBody({ type: GenerateEmbeddingsByCategoryRequestDto }) + @ApiResponse({ + status: HttpStatus.OK, + description: 'Embedding generation and storage results summary by category', + type: GenerateEmbeddingsResponseDto, + }) + @ApiBadRequestResponse({ description: 'Invalid request parameters' }) + @ApiInternalServerErrorResponse({ + description: 'Internal server error during embedding generation', + }) + async generateAndStoreEmbeddingsByCategory( + @Body() body: GenerateEmbeddingsByCategoryRequestDto, + ): Promise { + try { + const result = + await this.icdCodeVectorService.generateAndStoreEmbeddingsByCategory( + body.category, + body.limit, + ); + return { + message: `Processed ${result.processed} embeddings for category ${result.category} with ${result.errors} errors`, + ...result, + }; + } catch (error) { + throw new Error( + `Failed to generate embeddings for category ${body.category}: ${error.message}`, + ); + } + } + + @Post('vector-search') + @HttpCode(HttpStatus.OK) + @ApiOperation({ + summary: 'Search ICD codes using vector similarity', + description: + 'Search for ICD codes using vector similarity with configurable threshold and optional category filtering.', + }) + @ApiBody({ type: VectorSearchRequestDto }) + @ApiResponse({ + status: HttpStatus.OK, + description: 'Search results with vector similarity scores', + type: VectorSearchResponseDto, + }) + @ApiBadRequestResponse({ description: 'Invalid search parameters' }) + @ApiInternalServerErrorResponse({ + description: 'Internal server error during search', + }) + async vectorSearch( + @Body() body: VectorSearchRequestDto, + ): Promise { + try { + const results = await this.icdCodeVectorService.search( + body.query, + body.category, + body.limit, + ); + + return { + results, + total: results.length, + query: body.query, + category: body.category, + threshold: this.icdCodeVectorService.getThreshold(), + }; + } catch (error) { + throw new Error(`Vector search failed: ${error.message}`); + } + } + + @Get('vector-search') + @HttpCode(HttpStatus.OK) + @ApiOperation({ + summary: 'Search ICD codes using vector similarity (GET method)', + description: + 'Search for ICD codes using vector similarity with query parameters for easier testing.', + }) + @ApiQuery({ + name: 'query', + description: 'Text query to search for', + example: 'diabetes', + }) + @ApiQuery({ + name: 'category', + description: 'ICD category filter', + required: false, + enum: ['ICD9', 'ICD10'], + example: 'ICD10', + }) + @ApiQuery({ + name: 'limit', + description: 'Maximum number of results', + required: false, + type: Number, + example: 10, + }) + @ApiResponse({ + status: HttpStatus.OK, + description: 'Search results with vector similarity scores', + type: VectorSearchResponseDto, + }) + @ApiBadRequestResponse({ description: 'Invalid search parameters' }) + @ApiInternalServerErrorResponse({ + description: 'Internal server error during search', + }) + async vectorSearchGet( + @Query('query') query: string, + @Query('category') category?: 'ICD9' | 'ICD10', + @Query('limit') limit?: string, + ): Promise { + try { + const limitNumber = limit ? parseInt(limit) : 10; + const results = await this.icdCodeVectorService.search( + query, + category, + limitNumber, + ); + + return { + results, + total: results.length, + query, + category, + threshold: this.icdCodeVectorService.getThreshold(), + }; + } catch (error) { + throw new Error(`Vector search failed: ${error.message}`); + } + } + + @Get('vector-search/icd9') + @HttpCode(HttpStatus.OK) + @ApiOperation({ + summary: 'Search ICD-9 codes using vector similarity', + description: + 'Search for ICD-9 codes using vector similarity with configurable threshold.', + }) + @ApiQuery({ + name: 'query', + description: 'Text query to search for', + example: 'cardiac procedure', + }) + @ApiQuery({ + name: 'limit', + description: 'Maximum number of results', + required: false, + type: Number, + example: 10, + }) + @ApiResponse({ + status: HttpStatus.OK, + description: 'ICD-9 search results with vector similarity scores', + type: VectorSearchResponseDto, + }) + @ApiBadRequestResponse({ description: 'Invalid search parameters' }) + @ApiInternalServerErrorResponse({ + description: 'Internal server error during search', + }) + async vectorSearchICD9( + @Query('query') query: string, + @Query('limit') limit?: string, + ): Promise { + try { + const limitNumber = limit ? parseInt(limit) : 10; + const results = await this.icdCodeVectorService.searchByCategory( + query, + 'ICD9', + limitNumber, + ); + + return { + results, + total: results.length, + query, + category: 'ICD9', + threshold: this.icdCodeVectorService.getThreshold(), + }; + } catch (error) { + throw new Error(`ICD-9 vector search failed: ${error.message}`); + } + } + + @Get('vector-search/icd10') + @HttpCode(HttpStatus.OK) + @ApiOperation({ + summary: 'Search ICD-10 codes using vector similarity', + description: + 'Search for ICD-10 codes using vector similarity with configurable threshold.', + }) + @ApiQuery({ + name: 'query', + description: 'Text query to search for', + example: 'diabetes mellitus', + }) + @ApiQuery({ + name: 'limit', + description: 'Maximum number of results', + required: false, + type: Number, + example: 10, + }) + @ApiResponse({ + status: HttpStatus.OK, + description: 'ICD-10 search results with vector similarity scores', + type: VectorSearchResponseDto, + }) + @ApiBadRequestResponse({ description: 'Invalid search parameters' }) + @ApiInternalServerErrorResponse({ + description: 'Internal server error during search', + }) + async vectorSearchICD10( + @Query('query') query: string, + @Query('limit') limit?: string, + ): Promise { + try { + const limitNumber = limit ? parseInt(limit) : 10; + const results = await this.icdCodeVectorService.searchByCategory( + query, + 'ICD10', + limitNumber, + ); + + return { + results, + total: results.length, + query, + category: 'ICD10', + threshold: this.icdCodeVectorService.getThreshold(), + }; + } catch (error) { + throw new Error(`ICD-10 vector search failed: ${error.message}`); + } + } + + @Get('embedding-stats') + @HttpCode(HttpStatus.OK) + @ApiOperation({ + summary: 'Get embedding statistics', + description: 'Get statistics about ICD codes and their embedding status.', + }) + @ApiResponse({ + status: HttpStatus.OK, + description: 'Embedding statistics and current threshold', + }) + @ApiInternalServerErrorResponse({ + description: 'Internal server error getting statistics', + }) + async getEmbeddingStats() { + try { + return await this.icdCodeVectorService.getEmbeddingStats(); + } catch (error) { + throw new Error(`Failed to get embedding stats: ${error.message}`); + } + } + + @Get('threshold') + @HttpCode(HttpStatus.OK) + @ApiOperation({ + summary: 'Get current similarity threshold', + description: 'Get the current similarity threshold used for vector search.', + }) + @ApiResponse({ + status: HttpStatus.OK, + description: 'Current similarity threshold value', + }) + async getThreshold() { + try { + return { threshold: this.icdCodeVectorService.getThreshold() }; + } catch (error) { + throw new Error(`Failed to get threshold: ${error.message}`); + } + } +} diff --git a/src/icd-code/icd-code/icd-code.service.spec.ts b/src/icd-code/icd-code/icd-code.service.spec.ts new file mode 100644 index 0000000..4ee0b54 --- /dev/null +++ b/src/icd-code/icd-code/icd-code.service.spec.ts @@ -0,0 +1,18 @@ +import { Test, TestingModule } from '@nestjs/testing'; +import { IcdCodeService } from './icd-code.service'; + +describe('IcdCodeService', () => { + let service: IcdCodeService; + + beforeEach(async () => { + const module: TestingModule = await Test.createTestingModule({ + providers: [IcdCodeService], + }).compile(); + + service = module.get(IcdCodeService); + }); + + it('should be defined', () => { + expect(service).toBeDefined(); + }); +}); diff --git a/src/icd-code/icd-code/icd-code.service.ts b/src/icd-code/icd-code/icd-code.service.ts new file mode 100644 index 0000000..4e6d4fe --- /dev/null +++ b/src/icd-code/icd-code/icd-code.service.ts @@ -0,0 +1,37 @@ +import { Injectable } from '@nestjs/common'; +import { PrismaService } from 'src/common/prisma/prisma/prisma.service'; + +@Injectable() +export class IcdCodeService { + constructor(private readonly prisma: PrismaService) {} + + async findIcdCodes(search: string, page: number, limit: number) { + const where: any = {}; + + if (search) { + where.OR = [ + { code: { contains: search, mode: 'insensitive' } }, + { display: { contains: search, mode: 'insensitive' } }, + ]; + } + + const skip = (page - 1) * limit; + + const [data, total] = await Promise.all([ + this.prisma.icdCode.findMany({ + where, + skip, + take: limit, + }), + this.prisma.icdCode.count({ where }), + ]); + + return { + data, + total, + page, + limit, + totalPages: Math.ceil(total / limit), + }; + } +} diff --git a/src/icd/icd.controller.ts b/src/icd/icd.controller.ts index d0e0e8d..83f2d72 100644 --- a/src/icd/icd.controller.ts +++ b/src/icd/icd.controller.ts @@ -13,13 +13,17 @@ import { IcdStatisticsResponseDto, ErrorResponseDto, } from './dto/icd-response.dto'; +import { ResponseService } from 'src/common/response/response/response.service'; @ApiTags('ICD') @Controller('icd') export class IcdController { private readonly logger = new Logger(IcdController.name); - constructor(private readonly icdService: IcdService) {} + constructor( + private readonly icdService: IcdService, + private readonly responseService: ResponseService, + ) {} @Get('search') @ApiOperation({ @@ -72,7 +76,7 @@ export class IcdController { @Query('search') search?: string, @Query('page') page?: string, @Query('limit') limit?: string, - ): Promise { + ) { try { const pageNum = page ? parseInt(page, 10) : 1; const limitNum = limit ? parseInt(limit, 10) : 10; @@ -84,21 +88,15 @@ export class IcdController { limitNum, ); - return { - success: true, - data: result.data, - pagination: { - currentPage: result.page, - totalPages: result.totalPages, - totalItems: result.total, - itemsPerPage: result.limit, - hasNextPage: result.page < result.totalPages, - hasPreviousPage: result.page > 1, - }, - }; + return this.responseService.paginate( + result.data, + result.page, + result.limit, + result.total, + ); } catch (error) { this.logger.error('Error searching ICD codes:', error); - throw error; + return this.responseService.error('Internal server error during search'); } } @@ -117,16 +115,16 @@ export class IcdController { description: 'Internal server error while fetching statistics', type: ErrorResponseDto, }) - async getStatistics(): Promise { + async getStatistics() { try { const stats = await this.icdService.getStatistics(); - return { - success: true, - data: stats, - }; + + return this.responseService.success(stats); } catch (error) { this.logger.error('Error getting statistics:', error); - throw error; + return this.responseService.error( + 'Internal server error while fetching statistics', + ); } } } diff --git a/src/icd/icd.module.ts b/src/icd/icd.module.ts index 5d15022..037b71f 100644 --- a/src/icd/icd.module.ts +++ b/src/icd/icd.module.ts @@ -2,11 +2,13 @@ import { Module } from '@nestjs/common'; import { IcdController } from './icd.controller'; import { IcdService } from './icd.service'; import { PgVectorModule } from './pgvector.module'; +import { PrismaModule } from 'src/common/prisma/prisma.module'; +import { ResponseModule } from 'src/common/response/response.module'; @Module({ controllers: [IcdController], providers: [IcdService], - imports: [PgVectorModule], + imports: [PgVectorModule, PrismaModule, ResponseModule], exports: [IcdService, PgVectorModule], }) export class IcdModule {} diff --git a/src/icd/icd.service.ts b/src/icd/icd.service.ts index 8eafe4b..5503676 100644 --- a/src/icd/icd.service.ts +++ b/src/icd/icd.service.ts @@ -1,10 +1,10 @@ import { Injectable, Logger } from '@nestjs/common'; -import { PrismaClient } from '@prisma/client'; +import { PrismaService } from 'src/common/prisma/prisma/prisma.service'; @Injectable() export class IcdService { private readonly logger = new Logger(IcdService.name); - private readonly prisma = new PrismaClient(); + constructor(private readonly prisma: PrismaService) {} async findIcdCodes( category?: string, @@ -79,8 +79,4 @@ export class IcdService { throw error; } } - - async onModuleDestroy() { - await this.prisma.$disconnect(); - } }