Try to fix similarity and add seed for master Excel ICD

This commit is contained in:
arifal
2025-08-23 03:25:15 +07:00
parent b77beb2d85
commit 0ad656ce35
14 changed files with 1274 additions and 223 deletions

View File

@@ -64,7 +64,7 @@ export class HealthController {
status: 200,
description: 'Application is ready',
})
async getReady() {
getReady() {
return { status: 'ready' };
}
@@ -77,7 +77,7 @@ export class HealthController {
status: 200,
description: 'Application is alive',
})
async getLive() {
getLive() {
return { status: 'alive' };
}
}

View File

@@ -1,4 +1,4 @@
import { Controller, Get, Post, Query, Logger } from '@nestjs/common';
import { Controller, Get, Query, Logger } from '@nestjs/common';
import {
ApiTags,
ApiOperation,
@@ -8,10 +8,8 @@ import {
ApiInternalServerErrorResponse,
} from '@nestjs/swagger';
import { IcdService } from './icd.service';
import { SearchIcdDto } from './dto/search-icd.dto';
import {
IcdSearchResponseDto,
IcdImportResponseDto,
IcdStatisticsResponseDto,
ErrorResponseDto,
} from './dto/icd-response.dto';
@@ -23,40 +21,6 @@ export class IcdController {
constructor(private readonly icdService: IcdService) {}
/**
 * POST /import — runs the ICD import through IcdService and wraps the returned
 * counts in a success envelope. Any failure is logged and rethrown unchanged.
 *
 * @returns success flag, fixed message and the counts reported by the service
 */
@Post('import')
@ApiOperation({
  summary: 'Import ICD data from Excel files',
  description:
    'Import ICD-9 and ICD-10 codes from Excel files located in the test directory. This operation will process both ICD files and insert/update the database with the latest codes.',
})
@ApiResponse({
  status: 200,
  description: 'ICD data imported successfully',
  type: IcdImportResponseDto,
})
@ApiBadRequestResponse({
  description: 'Bad request - Invalid file format or missing files',
  type: ErrorResponseDto,
})
@ApiInternalServerErrorResponse({
  description: 'Internal server error during import process',
  type: ErrorResponseDto,
})
async importData(): Promise<IcdImportResponseDto> {
  try {
    this.logger.log('Starting ICD data import...');

    const importedCounts = await this.icdService.importIcdData();
    const response: IcdImportResponseDto = {
      success: true,
      message: 'ICD data imported successfully',
      data: importedCounts,
    };

    return response;
  } catch (importError) {
    // Log for operators, then let the framework translate the original error.
    this.logger.error('Error importing ICD data:', importError);
    throw importError;
  }
}
@Get('search')
@ApiOperation({
summary: 'Search ICD codes with filters and pagination',

View File

@@ -1,158 +1,11 @@
import { Injectable, Logger } from '@nestjs/common';
import { PrismaClient } from '../../generated/prisma';
import * as XLSX from 'xlsx';
import * as path from 'path';
import * as fs from 'fs';
/**
 * One ICD record as read from a spreadsheet row.
 * The Excel reader fills these from the first three columns of each row
 * (code, display text, version), trimmed and required to be non-empty.
 */
interface IcdData {
code: string;
display: string;
version: string;
}
import { PrismaClient } from '@prisma/client';
@Injectable()
export class IcdService {
private readonly logger = new Logger(IcdService.name);
private readonly prisma = new PrismaClient();
/**
 * Replaces the contents of the ICD code table with the rows found in the two
 * bundled Excel workbooks (ICD-9 CM and ICD-10).
 *
 * Both workbooks are read before anything is deleted, so an unreadable file
 * aborts the import while the existing rows are still in place.
 *
 * @returns number of ICD-9 rows, ICD-10 rows and their sum as reported by the bulk insert
 * @throws rethrows any read or database error after logging it
 */
async importIcdData(): Promise<{
  icd9Count: number;
  icd10Count: number;
  total: number;
}> {
  try {
    this.logger.log('Starting ICD data import...');

    // Load both sources up front.
    const icd9Rows = await this.readExcelFile(
      'test/[PUBLIC] ICD-9CM e-klaim.xlsx',
      'ICD9',
    );
    const icd10Rows = await this.readExcelFile(
      'test/[PUBLIC] ICD-10 e-klaim.xlsx',
      'ICD10',
    );

    // Wipe the table, then repopulate it category by category.
    await this.prisma.icdCode.deleteMany({});
    this.logger.log('Cleared existing ICD data');

    const icd9Count = await this.bulkInsertData(icd9Rows, 'ICD9');
    this.logger.log(`Imported ${icd9Count} ICD-9 codes`);

    const icd10Count = await this.bulkInsertData(icd10Rows, 'ICD10');
    this.logger.log(`Imported ${icd10Count} ICD-10 codes`);

    const summary = {
      icd9Count,
      icd10Count,
      total: icd9Count + icd10Count,
    };
    this.logger.log(`Total imported: ${summary.total} ICD codes`);

    return summary;
  } catch (importError) {
    this.logger.error('Error importing ICD data:', importError);
    throw importError;
  }
}
/**
 * Reads the first worksheet of an Excel workbook and converts its rows to ICD records.
 *
 * The first row is treated as a header and skipped. A row is kept only when it
 * has at least three cells and the first three (code, display, version) are all
 * non-empty after trimming.
 *
 * @param filePath workbook path, resolved relative to the process working directory
 * @param category label used only in log messages
 * @returns the valid records found in the sheet
 * @throws Error when the file does not exist; other read errors are logged and rethrown
 */
private async readExcelFile(
  filePath: string,
  category: string,
): Promise<IcdData[]> {
  try {
    const absolutePath = path.join(process.cwd(), filePath);
    if (!fs.existsSync(absolutePath)) {
      throw new Error(`File not found: ${absolutePath}`);
    }

    this.logger.log(`Reading ${category} file: ${filePath}`);

    const workbook = XLSX.readFile(absolutePath);
    const firstSheet = workbook.Sheets[workbook.SheetNames[0]];

    // header: 1 makes every row an array of cell values instead of a keyed object.
    const sheetRows = XLSX.utils.sheet_to_json(firstSheet, { header: 1 });

    const records: IcdData[] = [];
    for (const entry of sheetRows.slice(1)) {
      const cells = entry as any[];
      if (!cells || cells.length < 3) {
        continue;
      }

      const code = this.cleanString(cells[0]);
      const display = this.cleanString(cells[1]);
      const version = this.cleanString(cells[2]);
      if (!code || !display || !version) {
        continue;
      }

      records.push({ code, display, version });
    }

    this.logger.log(`Found ${records.length} valid ${category} records`);
    return records;
  } catch (readError) {
    this.logger.error(`Error reading ${category} file:`, readError);
    throw readError;
  }
}
/**
 * Inserts ICD records in batches of 1000 using `createMany`.
 *
 * Duplicates are skipped by the database (`skipDuplicates: true`), so the
 * number of rows actually written can be smaller than the batch size. The
 * total is therefore accumulated from the `count` returned by each
 * `createMany` call rather than from the batch length.
 *
 * @param data records to insert
 * @param category category value stored on every inserted row (also used in logs)
 * @returns number of rows actually inserted
 * @throws rethrows any database error after logging it
 */
private async bulkInsertData(
  data: IcdData[],
  category: string,
): Promise<number> {
  try {
    const batchSize = 1000;
    let totalInserted = 0;

    for (let offset = 0; offset < data.length; offset += batchSize) {
      const batch = data.slice(offset, offset + batchSize);
      const insertData = batch.map((item) => ({
        code: item.code,
        display: item.display,
        version: item.version,
        category,
      }));

      // `count` excludes rows skipped as duplicates.
      const { count } = await this.prisma.icdCode.createMany({
        data: insertData,
        skipDuplicates: true,
      });

      totalInserted += count;
      this.logger.log(
        `Inserted batch ${Math.floor(offset / batchSize) + 1} for ${category}: ${count} records`,
      );
    }

    return totalInserted;
  } catch (error) {
    this.logger.error(`Error inserting ${category} data:`, error);
    throw error;
  }
}
/**
 * Normalises a raw cell value to a trimmed string.
 *
 * @param value any cell value; `null` and `undefined` become the empty string
 * @returns `String(value)` with surrounding whitespace removed
 */
private cleanString(value: any): string {
  // `== null` matches exactly null and undefined.
  return value == null ? '' : String(value).trim();
}
async findIcdCodes(
category?: string,
search?: string,

View File

@@ -8,6 +8,15 @@ import {
ValidationPipe,
UsePipes,
} from '@nestjs/common';
import {
IsString,
IsOptional,
IsNumber,
IsEnum,
Min,
Max,
IsNotEmpty,
} from 'class-validator';
import {
ApiTags,
ApiOperation,
@@ -27,6 +36,8 @@ export class VectorSearchDto {
minLength: 1,
maxLength: 500,
})
@IsString()
@IsNotEmpty()
query: string;
@ApiProperty({
@@ -37,6 +48,10 @@ export class VectorSearchDto {
maximum: 100,
default: 10,
})
@IsOptional()
@IsNumber()
@Min(1)
@Max(100)
limit?: number;
@ApiProperty({
@@ -46,16 +61,22 @@ export class VectorSearchDto {
enum: ['ICD9', 'ICD10'],
default: undefined,
})
@IsOptional()
@IsEnum(['ICD9', 'ICD10'])
category?: string;
@ApiProperty({
description: 'Similarity threshold (0.0 - 1.0) for filtering results',
example: 0.7,
example: 0.85,
required: false,
minimum: 0.0,
maximum: 1.0,
default: 0.7,
default: 0.85,
})
@IsOptional()
@IsNumber()
@Min(0.0)
@Max(1.0)
threshold?: number;
}
@@ -66,6 +87,8 @@ export class EmbeddingRequestDto {
minLength: 1,
maxLength: 1000,
})
@IsString()
@IsNotEmpty()
text: string;
@ApiProperty({
@@ -74,9 +97,24 @@ export class EmbeddingRequestDto {
required: false,
default: 'text-embedding-ada-002',
})
@IsOptional()
@IsString()
model?: string;
}
/**
 * Request body for updating the similarity threshold.
 * The validation decorators require `threshold` to be a number between 0.0 and 1.0.
 */
export class ThresholdConfigDto {
// Similarity cut-off to apply to vector search results.
@ApiProperty({
description: 'Similarity threshold value (0.0 - 1.0)',
example: 0.85,
minimum: 0.0,
maximum: 1.0,
})
@IsNumber()
@Min(0.0)
@Max(1.0)
threshold: number;
}
export class VectorSearchResponseDto {
@ApiProperty({
description: 'Array of search results with similarity scores',
@@ -486,6 +524,61 @@ export class PgVectorController {
};
}
/**
 * POST /advanced-search — forwards the search parameters to the service's
 * advanced vector search and wraps the matches in the standard search response.
 *
 * A missing (or falsy) `limit` falls back to 10; `category` and `threshold`
 * are passed through as given.
 */
@Post('advanced-search')
@ApiOperation({
  summary: 'Advanced vector similarity search',
  description:
    'Advanced vector search using multiple similarity metrics (cosine + euclidean) for more accurate results with higher threshold.',
  tags: ['PgVector Operations'],
})
@ApiConsumes('application/json')
@ApiProduces('application/json')
@ApiBody({
  type: VectorSearchDto,
  description: 'Search parameters for advanced vector search',
  examples: {
    highPrecision: {
      summary: 'High precision search',
      value: {
        query: 'diabetes mellitus type 2',
        limit: 10,
        category: 'ICD10',
        threshold: 0.9,
      },
    },
  },
})
@ApiResponse({
  status: HttpStatus.OK,
  description:
    'Advanced vector search results with enhanced similarity scores',
  type: VectorSearchResponseDto,
})
@ApiResponse({
  status: HttpStatus.BAD_REQUEST,
  description: 'Invalid search parameters',
})
@ApiResponse({
  status: HttpStatus.INTERNAL_SERVER_ERROR,
  description: 'Error during advanced vector search',
})
async advancedVectorSearch(
  @Body() searchDto: VectorSearchDto,
): Promise<VectorSearchResponseDto> {
  const { query, category, threshold } = searchDto;
  const maxResults = searchDto.limit || 10;

  const matches = await this.pgVectorService.advancedVectorSearch(
    query,
    maxResults,
    category,
    threshold,
  );

  return {
    data: matches,
    total: matches.length,
    query,
  };
}
@Post('generate-embedding')
@ApiOperation({
summary: 'Generate text embedding',
@@ -570,6 +663,50 @@ export class PgVectorController {
};
}
/**
 * POST /regenerate-embeddings-enhanced — asks the service to regenerate existing
 * embeddings from the enhanced text representation and returns its counters
 * together with a human-readable summary message.
 *
 * The service is called without arguments, so its default sample size applies.
 */
@Post('regenerate-embeddings-enhanced')
@ApiOperation({
  summary: 'Regenerate embeddings with enhanced text representation',
  description:
    'Regenerate existing embeddings using enhanced text representation for better similarity scores. This improves search quality.',
  tags: ['PgVector Operations'],
})
@ApiConsumes('application/json')
@ApiProduces('application/json')
@ApiResponse({
  status: HttpStatus.OK,
  description: 'Enhanced embedding regeneration results summary',
  schema: {
    type: 'object',
    properties: {
      processed: { type: 'number', example: 100 },
      errors: { type: 'number', example: 0 },
      totalSample: { type: 'number', example: 100 },
      message: {
        type: 'string',
        example: 'Enhanced embeddings regenerated successfully',
      },
    },
  },
})
@ApiResponse({
  status: HttpStatus.INTERNAL_SERVER_ERROR,
  description: 'Error during enhanced embedding regeneration',
})
async regenerateEmbeddingsEnhanced(): Promise<{
  processed: number;
  errors: number;
  totalSample: number;
  message: string;
}> {
  const summary =
    await this.pgVectorService.regenerateEmbeddingsWithEnhancedText();
  const { processed, errors } = summary;
  const message = `Enhanced embeddings regenerated successfully. Processed: ${processed}, Errors: ${errors}`;

  return { ...summary, message };
}
@Get('stats')
@ApiOperation({
summary: 'Get embedding statistics',
@@ -640,6 +777,234 @@ export class PgVectorController {
};
}
/**
 * GET /threshold — reports the similarity threshold currently returned by the
 * service, together with a fixed explanatory description.
 */
@Get('threshold')
@ApiOperation({
  summary: 'Get current similarity threshold',
  description:
    'Get the current similarity threshold configuration used for vector search filtering.',
  tags: ['PgVector Operations'],
})
@ApiProduces('application/json')
@ApiResponse({
  status: HttpStatus.OK,
  description: 'Current similarity threshold configuration',
  schema: {
    type: 'object',
    properties: {
      threshold: {
        type: 'number',
        description: 'Current similarity threshold value',
        example: 0.85,
      },
      description: {
        type: 'string',
        description: 'Description of the threshold setting',
        example: 'Minimum similarity score required for search results',
      },
    },
  },
})
@ApiResponse({
  status: HttpStatus.INTERNAL_SERVER_ERROR,
  description: 'Error retrieving threshold configuration',
})
async getSimilarityThreshold(): Promise<{
  threshold: number;
  description: string;
}> {
  return {
    threshold: this.pgVectorService.getSimilarityThreshold(),
    description:
      'Minimum similarity score required for search results (0.0 - 1.0)',
  };
}
/**
 * GET /model — reports the embedding model name returned by the service and
 * whether it comes from the OPENAI_API_MODEL environment variable or the default.
 */
@Get('model')
@ApiOperation({
  summary: 'Get current embedding model',
  description:
    'Get the current OpenAI embedding model configuration used for vector generation.',
  tags: ['PgVector Operations'],
})
@ApiProduces('application/json')
@ApiResponse({
  status: HttpStatus.OK,
  description: 'Current embedding model configuration',
  schema: {
    type: 'object',
    properties: {
      model: {
        type: 'string',
        description: 'Current embedding model name',
        example: 'text-embedding-ada-002',
      },
      description: {
        type: 'string',
        description: 'Description of the model configuration',
        example: 'OpenAI embedding model for vector generation',
      },
      source: {
        type: 'string',
        description: 'Source of the model configuration',
        example: 'Environment Variable',
      },
    },
  },
})
@ApiResponse({
  status: HttpStatus.INTERNAL_SERVER_ERROR,
  description: 'Error retrieving model configuration',
})
async getEmbeddingModel(): Promise<{
  model: string;
  description: string;
  source: string;
}> {
  const activeModel = this.pgVectorService.getEmbeddingModel();

  // A non-empty OPENAI_API_MODEL means the name did not come from the built-in default.
  let origin = 'Default';
  if (process.env.OPENAI_API_MODEL) {
    origin = 'Environment Variable';
  }

  return {
    model: activeModel,
    description: 'OpenAI embedding model for vector generation',
    source: origin,
  };
}
/**
 * POST /threshold — stores a new similarity threshold through the service and
 * returns both the new value and the value that was active before the change.
 */
@Post('threshold')
@ApiOperation({
  summary: 'Set similarity threshold',
  description:
    'Set the similarity threshold for vector search filtering. Higher values result in more strict matching.',
  tags: ['PgVector Operations'],
})
@ApiConsumes('application/json')
@ApiProduces('application/json')
@ApiBody({
  type: ThresholdConfigDto,
  description: 'Threshold configuration parameters',
})
@ApiResponse({
  status: HttpStatus.OK,
  description: 'Similarity threshold updated successfully',
  schema: {
    type: 'object',
    properties: {
      message: {
        type: 'string',
        description: 'Success message',
        example: 'Similarity threshold updated successfully',
      },
      threshold: {
        type: 'number',
        description: 'Updated threshold value',
        example: 0.9,
      },
      previousThreshold: {
        type: 'number',
        description: 'Previous threshold value',
        example: 0.85,
      },
    },
  },
})
@ApiResponse({
  status: HttpStatus.BAD_REQUEST,
  description: 'Invalid threshold value (must be between 0.0 and 1.0)',
})
@ApiResponse({
  status: HttpStatus.INTERNAL_SERVER_ERROR,
  description: 'Error updating threshold configuration',
})
async setSimilarityThreshold(
  @Body() thresholdConfig: ThresholdConfigDto,
): Promise<{
  message: string;
  threshold: number;
  previousThreshold: number;
}> {
  // Capture the old value before it is overwritten.
  const previousThreshold = this.pgVectorService.getSimilarityThreshold();

  const { threshold } = thresholdConfig;
  this.pgVectorService.setSimilarityThreshold(threshold);

  return {
    message: 'Similarity threshold updated successfully',
    threshold,
    previousThreshold,
  };
}
/**
 * POST /model — switches the embedding model through the service and returns
 * the requested model name together with the one that was active before.
 */
@Post('model')
@ApiOperation({
  summary: 'Set embedding model',
  description:
    'Set the OpenAI embedding model for vector generation. This will reinitialize the embeddings service.',
  tags: ['PgVector Operations'],
})
@ApiConsumes('application/json')
@ApiProduces('application/json')
@ApiBody({
  schema: {
    type: 'object',
    properties: {
      model: {
        type: 'string',
        description: 'OpenAI embedding model name',
        example: 'text-embedding-ada-002',
      },
    },
    required: ['model'],
  },
  description: 'Model configuration parameters',
})
@ApiResponse({
  status: HttpStatus.OK,
  description: 'Embedding model updated successfully',
  schema: {
    type: 'object',
    properties: {
      message: {
        type: 'string',
        description: 'Success message',
        example: 'Embedding model updated successfully',
      },
      model: {
        type: 'string',
        description: 'Updated model name',
        example: 'text-embedding-ada-002',
      },
      previousModel: {
        type: 'string',
        description: 'Previous model name',
        example: 'text-embedding-ada-002',
      },
    },
  },
})
@ApiResponse({
  status: HttpStatus.BAD_REQUEST,
  description: 'Invalid model name',
})
@ApiResponse({
  status: HttpStatus.INTERNAL_SERVER_ERROR,
  description: 'Error updating model configuration',
})
async setEmbeddingModel(@Body() body: { model: string }): Promise<{
  message: string;
  model: string;
  previousModel: string;
}> {
  // Capture the old model name before the service replaces it.
  const previousModel = this.pgVectorService.getEmbeddingModel();

  const { model } = body;
  await this.pgVectorService.setEmbeddingModel(model);

  return {
    message: 'Embedding model updated successfully',
    model,
    previousModel,
  };
}
@Post('refresh')
@ApiOperation({
summary: 'Refresh pgvector store',

View File

@@ -1,8 +1,7 @@
import { Injectable, Logger } from '@nestjs/common';
import { PrismaClient } from '../../generated/prisma';
import { PrismaClient } from '@prisma/client';
import { OpenAIEmbeddings } from '@langchain/openai';
import { PGVectorStore } from '@langchain/community/vectorstores/pgvector';
import { Document } from 'langchain/document';
import { Pool } from 'pg';
export interface VectorSearchResult {
@@ -72,6 +71,41 @@ export class PgVectorService {
}
}
/**
 * Reinitialize OpenAI embeddings with a new model.
 *
 * Reads OPENAI_API_KEY from the environment, replaces `this.embeddings` with a
 * new `OpenAIEmbeddings` instance configured for `modelName`, and records the
 * model name in `process.env.OPENAI_API_MODEL`.
 *
 * @param modelName embedding model name to switch to
 * @throws Error prefixed with "Failed to reinitialize OpenAI embeddings:" when
 *   the API key is missing or the client cannot be constructed
 */
private async reinitializeEmbeddings(modelName: string): Promise<void> {
  try {
    const apiKey = process.env.OPENAI_API_KEY;
    if (!apiKey) {
      throw new Error('OPENAI_API_KEY not found');
    }

    this.logger.log(
      `Reinitializing OpenAI embeddings with model: ${modelName}`,
    );

    // Create new embeddings instance with new model
    this.embeddings = new OpenAIEmbeddings({
      openAIApiKey: apiKey,
      modelName: modelName,
      maxConcurrency: 5,
    });

    // Update environment variable to reflect current model
    process.env.OPENAI_API_MODEL = modelName;

    this.logger.log(
      `OpenAI embeddings reinitialized successfully with model: ${modelName}`,
    );
  } catch (error) {
    this.logger.error('Failed to reinitialize OpenAI embeddings:', error);

    // The caught value is not guaranteed to be an Error; narrow it before
    // reading `.message` so non-Error throws still produce a useful text.
    const reason = error instanceof Error ? error.message : String(error);
    throw new Error(`Failed to reinitialize OpenAI embeddings: ${reason}`);
  }
}
/**
* Initialize pgvector store dengan LangChain
*/
@@ -115,25 +149,35 @@ export class PgVectorService {
/**
* Generate embedding untuk text menggunakan OpenAI
*/
async generateEmbedding(
text: string,
model: string = 'text-embedding-ada-002',
): Promise<number[]> {
async generateEmbedding(text: string, model?: string): Promise<number[]> {
try {
// Get model from parameter, environment variable, or use default
const embeddingModel =
model || process.env.OPENAI_API_MODEL || 'text-embedding-ada-002';
this.logger.log(
`Generating embedding for text: ${text.substring(0, 100)}...`,
`Generating embedding for text: ${text.substring(0, 100)}... using model: ${embeddingModel}`,
);
// Check if we need to reinitialize embeddings with new model
const currentModel = this.getEmbeddingModel();
if (model && model !== currentModel) {
this.logger.log(
`Switching embedding model from ${currentModel} to ${model}`,
);
await this.reinitializeEmbeddings(model);
}
if (!this.embeddings) {
throw new Error(
'OpenAI embeddings not initialized. Please check your API configuration.',
);
}
// Use OpenAI embeddings
// Use OpenAI embeddings with current model
const embedding = await this.embeddings.embedQuery(text);
this.logger.log(
`Generated OpenAI embedding with ${embedding.length} dimensions`,
`Generated OpenAI embedding with ${embedding.length} dimensions using model: ${this.getEmbeddingModel()}`,
);
return embedding;
} catch (error) {
@@ -191,7 +235,8 @@ export class PgVectorService {
`UPDATE icd_codes
SET embedding = $1::vector,
metadata = $2::jsonb,
content = $3
content = $3,
"updatedAt" = NOW()
WHERE id = $4`,
[
vectorString,
@@ -289,7 +334,8 @@ export class PgVectorService {
`UPDATE icd_codes
SET embedding = $1::vector,
metadata = $2::jsonb,
content = $3
content = $3,
"updatedAt" = NOW()
WHERE id = $4`,
[
vectorString,
@@ -337,16 +383,23 @@ export class PgVectorService {
}
/**
* Vector similarity search menggunakan pgvector
* Vector similarity search menggunakan pgvector dengan threshold yang dapat dikonfigurasi
*/
async vectorSearch(
query: string,
limit: number = 10,
category?: string,
threshold: number = 0.7,
threshold?: number,
): Promise<VectorSearchResult[]> {
// Get threshold from environment variable or use default
const defaultThreshold = parseFloat(
process.env.VECTOR_SIMILARITY_THRESHOLD || '0.85',
);
const similarityThreshold = threshold || defaultThreshold;
try {
this.logger.log(`Performing pgvector search for: ${query}`);
this.logger.log(
`Performing pgvector search for: ${query} with threshold: ${similarityThreshold}`,
);
if (!this.embeddings) {
throw new Error('OpenAI embeddings not initialized');
@@ -358,17 +411,19 @@ export class PgVectorService {
// Convert embedding array to proper vector format for pgvector
const vectorString = `[${queryEmbedding.join(',')}]`;
// Build SQL query for vector similarity search
// Build SQL query for vector similarity search with higher precision
// Using cosine distance and converting to similarity score
let sql = `
SELECT
id, code, display, version, category,
1 - (embedding <=> $1::vector) as similarity
(1 - (embedding <=> $1::vector)) as similarity
FROM icd_codes
WHERE embedding IS NOT NULL
AND (1 - (embedding <=> $1::vector)) >= $2
`;
const params: any[] = [vectorString];
let paramIndex = 2;
const params: any[] = [vectorString, similarityThreshold];
let paramIndex = 3;
if (category) {
sql += ` AND category = $${paramIndex}`;
@@ -376,23 +431,24 @@ export class PgVectorService {
paramIndex++;
}
sql += ` ORDER BY embedding <=> $1::vector ASC LIMIT $${paramIndex}`;
// Order by similarity descending and limit results
sql += ` ORDER BY similarity DESC LIMIT $${paramIndex}`;
params.push(limit);
// Execute raw SQL query
const result = await this.pool.query(sql, params);
// Transform and filter results
const filteredResults: VectorSearchResult[] = result.rows
.filter((row: any) => row.similarity >= threshold)
.map((row: any) => ({
// Transform results (no need to filter again since SQL already filters)
const filteredResults: VectorSearchResult[] = result.rows.map(
(row: any) => ({
id: row.id,
code: row.code,
display: row.display,
version: row.version,
category: row.category,
similarity: parseFloat(row.similarity),
}));
}),
);
this.logger.log(
`Pgvector search returned ${filteredResults.length} results for query: "${query}"`,
@@ -417,12 +473,12 @@ export class PgVectorService {
try {
this.logger.log(`Performing hybrid search for: ${query}`);
// Get vector search results
// Get vector search results with higher threshold
const vectorResults = await this.vectorSearch(
query,
limit * 2,
category,
0.5,
parseFloat(process.env.VECTOR_SIMILARITY_THRESHOLD || '0.85'),
);
// Get text search results
@@ -475,7 +531,7 @@ export class PgVectorService {
try {
let sql = 'SELECT id, code, display, version, category FROM icd_codes';
const params: any[] = [];
let whereConditions: string[] = [];
const whereConditions: string[] = [];
let paramIndex = 1;
if (category) {
@@ -515,6 +571,51 @@ export class PgVectorService {
}
}
/**
 * Get current similarity threshold configuration.
 *
 * Reads VECTOR_SIMILARITY_THRESHOLD from the environment. When the variable is
 * unset, empty, not a number, or outside [0, 1], the default of 0.85 is
 * returned so callers never receive NaN or an out-of-range value.
 *
 * @returns threshold in the range [0, 1]
 */
getSimilarityThreshold(): number {
  const fallbackThreshold = 0.85;
  const parsed = parseFloat(process.env.VECTOR_SIMILARITY_THRESHOLD || '');

  const isUsable = Number.isFinite(parsed) && parsed >= 0 && parsed <= 1;
  return isUsable ? parsed : fallbackThreshold;
}
/**
 * Get current embedding model configuration.
 *
 * @returns the value of OPENAI_API_MODEL when it is set and non-empty,
 *   otherwise 'text-embedding-ada-002'
 */
getEmbeddingModel(): string {
  const configuredModel = process.env.OPENAI_API_MODEL;
  if (configuredModel) {
    return configuredModel;
  }
  return 'text-embedding-ada-002';
}
/**
 * Set similarity threshold (for runtime configuration).
 *
 * The value is stored in `process.env.VECTOR_SIMILARITY_THRESHOLD`, which is
 * where the threshold is read back from.
 *
 * @param threshold new threshold; must be a finite number in [0, 1]
 * @throws Error when the value is NaN, infinite, or outside [0, 1]
 */
setSimilarityThreshold(threshold: number): void {
  // NaN fails both `< 0` and `> 1`, so it must be rejected explicitly.
  if (!Number.isFinite(threshold) || threshold < 0 || threshold > 1) {
    throw new Error('Similarity threshold must be between 0 and 1');
  }

  process.env.VECTOR_SIMILARITY_THRESHOLD = threshold.toString();
  this.logger.log(`Similarity threshold updated to: ${threshold}`);
}
/**
 * Set embedding model (for runtime configuration).
 *
 * Does nothing except log when the requested model is already the active one;
 * otherwise reinitializes the embeddings client with the new model.
 *
 * @param modelName non-empty model name
 * @throws Error when `modelName` is empty or not a string
 */
async setEmbeddingModel(modelName: string): Promise<void> {
  const isUsableName = typeof modelName === 'string' && modelName.length > 0;
  if (!isUsableName) {
    throw new Error('Model name must be a valid string');
  }

  const activeModel = this.getEmbeddingModel();
  if (modelName !== activeModel) {
    this.logger.log(
      `Switching embedding model from ${activeModel} to ${modelName}`,
    );
    await this.reinitializeEmbeddings(modelName);
    return;
  }

  this.logger.log(`Model ${modelName} is already active`);
}
/**
* Get embedding statistics
*/
@@ -524,6 +625,7 @@ export class PgVectorService {
withoutEmbeddings: number;
percentage: number;
vectorStoreStatus: string;
currentThreshold: number;
}> {
try {
// Use raw SQL to get embedding statistics
@@ -548,6 +650,7 @@ export class PgVectorService {
withoutEmbeddings,
percentage: Math.round(percentage * 100) / 100,
vectorStoreStatus,
currentThreshold: this.getSimilarityThreshold(),
};
} catch (error) {
this.logger.error('Error getting embedding stats:', error);
@@ -569,6 +672,95 @@ export class PgVectorService {
}
}
/**
 * Advanced vector search that combines several similarity metrics.
 *
 * Rows are filtered in SQL by cosine similarity (`1 - (embedding <=> query)`)
 * against the threshold and ordered by cosine similarity, then by
 * `1 - (embedding <-> query)`. The `similarity` reported for each row is a
 * weighted blend (70% cosine, 30% euclidean-based), rounded to three decimals,
 * so it can differ from the cosine value used for filtering.
 *
 * @param query free-text query to embed and search for
 * @param limit maximum number of rows to return (default 10)
 * @param category optional category filter (matched exactly)
 * @param threshold minimum cosine similarity; when omitted, the value of
 *   VECTOR_SIMILARITY_THRESHOLD is used, or 0.85 if that is unset or invalid.
 *   An explicit 0 is honoured and disables the effective lower bound.
 * @returns matching rows with the blended similarity score
 * @throws Error when embeddings are not initialized; other errors are logged and rethrown
 */
async advancedVectorSearch(
  query: string,
  limit: number = 10,
  category?: string,
  threshold?: number,
): Promise<VectorSearchResult[]> {
  try {
    // Get threshold from environment variable or use default; a malformed
    // environment value must not leak NaN into the SQL filter.
    const envThreshold = parseFloat(
      process.env.VECTOR_SIMILARITY_THRESHOLD || '0.85',
    );
    const defaultThreshold = Number.isFinite(envThreshold)
      ? envThreshold
      : 0.85;

    // `??` rather than `||`: 0 is a valid threshold and must not be replaced.
    const similarityThreshold = threshold ?? defaultThreshold;

    this.logger.log(
      `Performing advanced vector search for: ${query} with threshold: ${similarityThreshold}`,
    );

    if (!this.embeddings) {
      throw new Error('OpenAI embeddings not initialized');
    }

    // Generate embedding for query
    const queryEmbedding = await this.generateEmbedding(query);
    const vectorString = `[${queryEmbedding.join(',')}]`;

    // Advanced SQL query using multiple similarity metrics
    let sql = `
      SELECT
        id, code, display, version, category,
        (1 - (embedding <=> $1::vector)) as cosine_similarity,
        (1 - (embedding <-> $1::vector)) as euclidean_similarity,
        (embedding <#> $1::vector) as negative_inner_product
      FROM icd_codes
      WHERE embedding IS NOT NULL
    `;

    const params: any[] = [vectorString];
    let paramIndex = 2;

    if (category) {
      sql += ` AND category = $${paramIndex}`;
      params.push(category);
      paramIndex++;
    }

    // Filter by cosine similarity threshold
    sql += ` AND (1 - (embedding <=> $1::vector)) >= $${paramIndex}`;
    params.push(similarityThreshold);
    paramIndex++;

    // Order by combined similarity score and limit
    sql += ` ORDER BY cosine_similarity DESC, euclidean_similarity DESC LIMIT $${paramIndex}`;
    params.push(limit);

    const result = await this.pool.query(sql, params);

    // Transform results with enhanced similarity scoring
    const filteredResults: VectorSearchResult[] = result.rows.map(
      (row: any) => {
        const cosineSim = parseFloat(row.cosine_similarity);
        const euclideanSim = parseFloat(row.euclidean_similarity);

        // Calculate combined similarity score (weighted average)
        const combinedSimilarity = cosineSim * 0.7 + euclideanSim * 0.3;

        return {
          id: row.id,
          code: row.code,
          display: row.display,
          version: row.version,
          category: row.category,
          similarity: Math.round(combinedSimilarity * 1000) / 1000, // Round to 3 decimal places
        };
      },
    );

    this.logger.log(
      `Advanced vector search returned ${filteredResults.length} results for query: "${query}" with threshold: ${similarityThreshold}`,
    );

    return filteredResults;
  } catch (error) {
    this.logger.error('Error in advanced vector search:', error);
    throw error;
  }
}
/**
* Get vector store status
*/
@@ -589,9 +781,10 @@ export class PgVectorService {
initialized: !!this.vectorStore,
documentCount,
embeddingModel: this.embeddings
? `OpenAI ${process.env.OPENAI_API_MODEL || 'text-embedding-ada-002'}`
? `OpenAI ${this.getEmbeddingModel()}`
: 'Not Available',
lastUpdated: new Date(),
currentThreshold: this.getSimilarityThreshold(),
};
return status;
@@ -601,6 +794,164 @@ export class PgVectorService {
}
}
/**
 * Create enhanced text representation for better embedding quality.
 *
 * Starts from "<code> <display>" and appends, in order: a category phrase
 * (ICD9 / ICD10), the version if present, one context phrase for every keyword
 * group that matches the lower-cased display text, and a fixed trailing phrase.
 *
 * @param code row object with `code`, `display`, `category` and `version` fields
 * @returns the concatenated text to embed
 */
private createEnhancedTextRepresentation(code: any): string {
  const parts: string[] = [`${code.code} ${code.display}`];

  // Category context
  switch (code.category) {
    case 'ICD9':
      parts.push(` ICD-9 CM procedure diagnosis`);
      break;
    case 'ICD10':
      parts.push(` ICD-10 diagnosis condition`);
      break;
  }

  // Version context
  if (code.version) {
    parts.push(` ${code.version}`);
  }

  // Keyword groups checked against the display text, in the order their
  // phrases are appended: procedure, diagnosis, then anatomical contexts.
  const keywordContexts: Array<[string[], string]> = [
    [
      ['procedure', 'surgery', 'operation'],
      ' medical procedure surgical intervention',
    ],
    [
      ['diagnosis', 'condition', 'disease'],
      ' medical diagnosis clinical condition',
    ],
    [['cranial', 'brain', 'head'], ' neurological cranial brain head'],
    [['cardiac', 'heart', 'cardiovascular'], ' cardiac heart cardiovascular'],
    [['pulmonary', 'lung', 'respiratory'], ' pulmonary lung respiratory'],
  ];

  const loweredDisplay = code.display.toLowerCase();
  for (const [keywords, contextPhrase] of keywordContexts) {
    if (keywords.some((keyword) => loweredDisplay.includes(keyword))) {
      parts.push(contextPhrase);
    }
  }

  // Common medical terms appended to every entry
  parts.push(' medical healthcare clinical');

  return parts.join('');
}
/**
 * Regenerate embeddings with enhanced text representation for better similarity.
 *
 * Selects up to `limit` rows that already have an embedding, rebuilds each
 * row's text with the enhanced representation, generates a new embedding and
 * writes embedding, content and "updatedAt" back. A failure on one row is
 * logged and counted; processing continues with the next row.
 *
 * @param limit maximum number of rows to regenerate (default 100)
 * @returns count of rows updated, rows that failed, and rows selected
 * @throws rethrows errors raised outside the per-row handling (e.g. the initial query)
 */
async regenerateEmbeddingsWithEnhancedText(limit: number = 100): Promise<{
  processed: number;
  errors: number;
  totalSample: number;
}> {
  try {
    this.logger.log(
      `Starting enhanced embedding regeneration for ${limit} ICD codes...`,
    );

    // Only rows that already carry an embedding are candidates.
    const { rows: candidates } = await this.pool.query(
      'SELECT id, code, display, version, category FROM icd_codes WHERE embedding IS NOT NULL LIMIT $1',
      [limit],
    );
    const totalSample = candidates.length;

    if (totalSample === 0) {
      this.logger.log('No ICD codes found with embeddings to regenerate');
      return { processed: 0, errors: 0, totalSample: 0 };
    }

    this.logger.log(
      `Found ${totalSample} codes to regenerate with enhanced text`,
    );

    let processed = 0;
    let errors = 0;

    for (const icdRow of candidates) {
      try {
        const enhancedText = this.createEnhancedTextRepresentation(icdRow);
        const freshEmbedding = await this.generateEmbedding(enhancedText);

        // pgvector accepts the vector as a bracketed, comma-separated literal.
        const vectorLiteral = `[${freshEmbedding.join(',')}]`;

        await this.pool.query(
          `UPDATE icd_codes
           SET embedding = $1::vector,
               content = $2,
               "updatedAt" = NOW()
           WHERE id = $3`,
          [vectorLiteral, enhancedText, icdRow.id],
        );

        processed += 1;

        // Progress log every ten successful updates.
        if (processed % 10 === 0) {
          this.logger.log(
            `Regenerated ${processed}/${totalSample} enhanced embeddings`,
          );
        }
      } catch (rowError) {
        this.logger.error(
          `Error regenerating embedding for code ${icdRow.code}:`,
          rowError,
        );
        errors += 1;
      }
    }

    this.logger.log(
      `Enhanced embedding regeneration completed. Processed: ${processed}, Errors: ${errors}, Total: ${totalSample}`,
    );

    return { processed, errors, totalSample };
  } catch (fatalError) {
    this.logger.error(
      'Error in regenerateEmbeddingsWithEnhancedText:',
      fatalError,
    );
    throw fatalError;
  }
}
/**
* Cleanup resources
*/