[jblas] 01/02: Imported Upstream version 1.2.3
Tony Mancill
tmancill at moszumanska.debian.org
Fri Jan 10 04:41:43 UTC 2014
This is an automated email from the git hooks/post-receive script.
tmancill pushed a commit to branch master
in repository jblas.
commit 9c144335cf64eeda4da5e96efb7ce261cf401e7f
Author: tony mancill <tmancill at debian.org>
Date: Thu Jan 9 20:40:44 2014 -0800
Imported Upstream version 1.2.3
---
.gitignore | 9 +
.travis.yml | 8 +
AUTHORS | 5 +-
Makefile | 7 +-
README => README.md | 32 +-
RELEASE_NOTES | 45 +
ROADMAP | 11 +
build.xml | 22 +-
config/config.rb | 4 +-
config/config_cc.rb | 39 +-
config/config_fortran.rb | 18 +-
config/config_java.rb | 10 +-
config/config_lapack_sources.rb | 6 +-
config/config_libs.rb | 48 +-
config/config_make.rb | 4 +-
config/config_os_arch.rb | 2 +-
config/config_tools.rb | 2 +-
config/configure.rb | 5 +-
config/lib_helpers.rb | 32 +-
config/path.rb | 10 +-
config/windows.rb | 3 +-
fortranwrapper.dump | Bin 22347 -> 25090 bytes
pom.xml | 189 ++--
scripts/class_to_float.rb | 20 +-
scripts/fortran/java.rb | 17 +
scripts/fortran/parser.rb | 1 +
scripts/fortran/types.rb | 3 +-
scripts/java-class.java | 20 +-
scripts/java-impl.c | 2 +
scripts/rjpp.rb | 17 +-
src/main/c/NativeBlas.c | 496 ++++++++-
src/main/c/jblas_arch_flavor.c | 2 +
src/main/c/org_jblas_NativeBlas.h | 59 +-
src/main/c/org_jblas_util_ArchFlavor.h | 1 -
src/main/java/org/jblas/ComplexDoubleMatrix.java | 72 +-
src/main/java/org/jblas/ComplexFloat.java | 2 +-
src/main/java/org/jblas/ComplexFloatMatrix.java | 74 +-
src/main/java/org/jblas/Decompose.java | 159 ++-
src/main/java/org/jblas/DoubleMatrix.java | 120 ++-
src/main/java/org/jblas/FloatFunction.java | 2 +-
src/main/java/org/jblas/FloatMatrix.java | 122 ++-
src/main/java/org/jblas/Info.java | 14 +
src/main/java/org/jblas/NativeBlas.java | 101 +-
.../java/org/jblas/NativeBlasLibraryLoader.java | 70 ++
src/main/java/org/jblas/SimpleBlas.java | 138 ++-
src/main/java/org/jblas/Singular.java | 134 ++-
src/main/java/org/jblas/Solve.java | 78 ++
src/main/java/org/jblas/benchmark/Main.java | 7 +-
.../UnsupportedArchitectureException.java | 16 +
src/main/java/org/jblas/ranges/AllRange.java | 74 +-
src/main/java/org/jblas/ranges/IntervalRange.java | 69 +-
src/main/java/org/jblas/ranges/PointRange.java | 65 +-
src/main/java/org/jblas/util/ArchFlavor.java | 5 +-
src/main/java/org/jblas/util/Functions.java | 6 +
src/main/java/org/jblas/util/LibraryLoader.java | 405 ++++----
src/main/java/org/jblas/util/Permutations.java | 30 +-
src/main/java/org/jblas/util/Random.java | 33 +
.../lib/static/Linux/amd64/libjblas_arch_flavor.so | Bin 7811 -> 6003 bytes
.../lib/static/Linux/amd64/sse2/libjblas.so | Bin 6240026 -> 0 bytes
.../lib/static/Linux/amd64/sse3/libjblas.so | Bin 4751359 -> 7735984 bytes
.../lib/static/Linux/i386/libjblas_arch_flavor.so | Bin 5413 -> 6871 bytes
.../lib/static/Linux/i386/sse2/libjblas.so | Bin 5353071 -> 0 bytes
.../lib/static/Linux/i386/sse3/libjblas.so | Bin 5578487 -> 5279230 bytes
.../Mac OS X/x86_64/libjblas_arch_flavor.jnilib | Bin 8496 -> 4356 bytes
.../static/Mac OS X/x86_64/sse3/libjblas.jnilib | Bin 7760504 -> 6121912 bytes
.../resources/lib/static/Windows/amd64/jblas.dll | Bin 5116529 -> 2166887 bytes
.../lib/static/Windows/amd64/jblas_arch_flavor.dll | Bin 121365 -> 250050 bytes
.../lib/static/Windows/amd64/libgcc_s_sjlj-1.dll | Bin 0 -> 99328 bytes
.../lib/static/Windows/amd64/libgfortran-3.dll | Bin 0 -> 937472 bytes
.../lib/static/Windows/x86/jblas_arch_flavor.dll | Bin 38184 -> 66580 bytes
.../lib/static/Windows/x86/libgcc_s_dw2-1.dll | Bin 0 -> 101902 bytes
.../lib/static/Windows/x86/libgfortran-3.dll | Bin 0 -> 787470 bytes
.../lib/static/Windows/x86/sse2/jblas.dll | Bin 5313362 -> 0 bytes
.../lib/static/Windows/x86/sse3/jblas.dll | Bin 5478336 -> 5672426 bytes
.../java/org/jblas/ComplexDoubleMatrixTest.java | 29 +-
.../java/org/jblas/JblasAssert.java} | 28 +-
src/test/java/org/jblas/SimpleBlasTest.java | 23 +-
src/test/java/org/jblas/TestBlasDouble.java | 306 +++---
src/test/java/org/jblas/TestBlasDoubleComplex.java | 10 +-
src/test/java/org/jblas/TestBlasFloat.java | 308 +++---
src/test/java/org/jblas/TestComplexFloat.java | 32 +-
src/test/java/org/jblas/TestDecompose.java | 94 ++
src/test/java/org/jblas/TestDoubleMatrix.java | 1049 ++++++++++---------
src/test/java/org/jblas/TestEigen.java | 71 +-
src/test/java/org/jblas/TestFloatMatrix.java | 1050 +++++++++++---------
src/test/java/org/jblas/TestGeometry.java | 82 +-
src/test/java/org/jblas/TestSingular.java | 81 ++
src/test/java/org/jblas/TestSolve.java | 114 +++
src/test/java/org/jblas/ranges/RangeTest.java | 44 +
89 files changed, 4242 insertions(+), 2024 deletions(-)
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2349e18
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,9 @@
+# maven target repo
+
+target/
+
+# idea files
+.idea/
+*.ipr
+*.iml
+*.iws
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 0000000..ecf92a6
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,8 @@
+language: java
+notifications:
+ email: mikiobraun at gmail.com
+ on_success: never
+ on_failure: always
+
+before_install:
+ - sudo apt-get install -q libgfortran3
diff --git a/AUTHORS b/AUTHORS
index dcadd5e..32177c5 100644
--- a/AUTHORS
+++ b/AUTHORS
@@ -7,4 +7,7 @@ Additional Programming and Contributions:
Johannes Schaback
Jan Saputra Müller (exponential matrix multiplication, decomposition)
Matthias L. Jugel (packaging)
-Nicolas Oury (generalized eigenvectors)
\ No newline at end of file
+Nicolas Oury (generalized eigenvectors)
+http://github.com/cheshirekow (fixes with range objects)
+Quantisan (travis integration)
+robbymckilliam (fixes with complex SVD)
\ No newline at end of file
diff --git a/Makefile b/Makefile
index c888190..b0ae800 100644
--- a/Makefile
+++ b/Makefile
@@ -32,7 +32,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
## --- END LICENSE BLOCK ---
-VERSION=1.2.0
+VERSION=1.2.1
######################################################################
#
@@ -140,7 +140,10 @@ generated-sources: \
$(LAPACK)/[sd]getrf.f \
$(LAPACK)/[sd]potrf.f \
$(LAPACK)/[sdcz]gesvd.f \
- $(LAPACK)/[sd]sygvd.f
+ $(LAPACK)/[sd]sygvd.f \
+ $(LAPACK)/[sd]gelsd.f \
+ $(LAPACK)/ilaenv.f \
+ $(LAPACK)/[sd]geqrf.f $(LAPACK)/[sd]ormqr.f
ant javah
touch $@
diff --git a/README b/README.md
similarity index 74%
rename from README
rename to README.md
index d1b8cc9..d837e81 100644
--- a/README
+++ b/README.md
@@ -1,23 +1,30 @@
jblas is a matrix library for Java which uses existing high
performance BLAS and LAPACK libraries like ATLAS.
-Version 1.2.0, January 7, 2011
-Version 1.1.1
-Version 1.1, August 16, 2010
-Version 1.0.2, February 26, 2010
-Version 1.0.1, January 14, 2010
-Version 1.0, December 22, 2009
-Version 0.3, September 17, 2009
-Version 0.2, May 8, 2009
-Version 0.1, March 28, 2009
+* Version 1.2.3, February 13, 2013
+* Version 1.2.2, December 17, 2012
+* Version 1.2.1
+* Version 1.2.0, January 7, 2011
+* Version 1.1.1
+* Version 1.1, August 16, 2010
+* Version 1.0.2, February 26, 2010
+* Version 1.0.1, January 14, 2010
+* Version 1.0, December 22, 2009
+* Version 0.3, September 17, 2009
+* Version 0.2, May 8, 2009
+* Version 0.1, March 28, 2009
see also the file RELEASE_NOTES
Homepage: http://jblas.org
+![travis status](https://travis-ci.org/mikiobraun/jblas.png)
+Travis Page: https://travis-ci.org/mikiobraun/jblas
+
INSTALL
+-------
-In principle, all you need is the jblas-1.2.0,jar in your
+In principle, all you need is the jblas-1.2.0.jar in your
classpath. jblas-1.2.0.jar will then automagically extract your platform
dependent native library to a tempfile and load it from there. You can
also put that file somewhere in your load path ($LD_LIBRARY_PATH for
@@ -25,6 +32,7 @@ Linux, %PATH for Windows).
BUILDING
+--------
If you only work on the java part, an ant build.xml is provided to
recompile the sources. In addition to that you need an installation of
@@ -45,6 +53,7 @@ further details.
HOW TO GET STARTED
+------------------
Have a look at javadoc/index.html and
javadoc/org/jblas/DoubleMatrix.html
@@ -55,12 +64,14 @@ in case, you only have the "client" JVM installed.
LICENSE
+-------
jblas is distributed under a BSD-style license. See the file COPYING
for more information.
BUGS
+----
If you encounter any bugs, feel free to go to http://jblas.org and
register a ticket for them. Make sure to include as much information
@@ -69,5 +80,6 @@ include the file "configure.log".
CONTRIBUTORS
+------------
see file AUTHORS
\ No newline at end of file
diff --git a/RELEASE_NOTES b/RELEASE_NOTES
index 3a9949a..9312234 100644
--- a/RELEASE_NOTES
+++ b/RELEASE_NOTES
@@ -1,3 +1,48 @@
+Release 1.2.3
+
+New features
+
+- LU decomposition for float
+- Least squares and pseudo-inverse added to Solve.
+- QR decomposition.
+- Removed dependency on external libgfortran for Windows.
+- Now also runs on CentOS and other Linux 2.6 distros out of the box.
+
+Removed features
+
+- Support for SSE2. I don't have access to such processors anymore. Sorry.
+
+Bug fixes
+
+- load and save didn't close the streams.
+- Bug in maxi and mini.
+- Recompiled on Linux 2.6.32 to resolve glibc version dependency
+ problems. Now also runs on CentOS.
+
+Release 1.2.2
+
+New features
+
+- full SVD for complex matrices
+- Added Travis CI support on GitHub
+- recompiled ATLAS for Linux, moved to 3.10.0, lapack 3.4.2
+- upload to central maven repo
+
+Bug fixes
+
+- More meaningful error messages for Linux/64 and Windows
+ when libraries are missing
+- SingularValueDecomposition for complex matrices used transpose()
+ instead of hermitian()
+
+Release 1.2.1
+
+Bug fixes and code side changes.
+
+- Fixed bugs with Range objects
+- Moved test code to JUnit 4.
+- Fixed configure script and code generation to work with Ruby 1.9
+
Release 1.2.0 - January 7, 2011
- Added Generalized Eigenvalues for symmetric matrices (which you
diff --git a/ROADMAP b/ROADMAP
index a07f024..1fde85c 100644
--- a/ROADMAP
+++ b/ROADMAP
@@ -1,3 +1,14 @@
+Roadmap to jblas 1.2.1
+
+1. Linux / 32 bit / sse2
+2. Linux / 32 bit / sse3
+3. Linux / 64 bit / sse2
+4. Linux / 64 bit / sse3
+5. Mac OS X / 64 bit / sse3
+6. Windows / 32 bit / sse2
+7. Windows / 32 bit / sse3
+8. Windows / 64 bit / lapack lite only
+
Roadmap to jblas-0.3.1
- collect static builds for Windows(32), Linux(32/64), MacOSX(32/64).
diff --git a/build.xml b/build.xml
index 7ba2f9a..5598460 100644
--- a/build.xml
+++ b/build.xml
@@ -41,7 +41,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
<!-- Define directories -->
- <property name="version" value="1.2.0" />
+ <property name="version" value="1.2.3-SNAPSHOT" />
<property name="src" value="${basedir}/src/main/java" />
<property name="test" value="${basedir}/src/test/java" />
<property name="bin" value="${basedir}/target/classes" />
@@ -53,6 +53,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
<property name="external" value="${basedir}/external" />
<property name="scripts" value="${basedir}/scripts" />
<property name="pkgbase" value="org.jblas" />
+ <property name="ruby" value="ruby" />
<!-- Macros -->
@@ -61,8 +62,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
<attribute name="path" default="${src}"/>
<sequential>
<echo message="Generating float version of @{class}"/>
- <exec executable="ruby">
- <arg line="scripts/class_to_float.rb @{path} @{class}"/>
+ <exec executable="${ruby}">
+ <arg line="scripts/class_to_float.rb "@{path}" @{class}"/>
</exec>
</sequential>
</macrodef>
@@ -70,16 +71,16 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
<attribute name="class"/>
<sequential>
<echo message="Add float versions to class @{class}"/>
- <exec executable="ruby">
- <arg line="scripts/static_class_to_float.rb ${src} @{class}"/>
+ <exec executable="${ruby}">
+ <arg line="scripts/static_class_to_float.rb "${src}" @{class}"/>
</exec>
</sequential>
</macrodef>
<macrodef name="rjpp" description="Run the ruby-java preprocessor.">
<attribute name="file"/>
<sequential>
- <exec executable="ruby">
- <arg line="scripts/rjpp.rb @{file}"/>
+ <exec executable="${ruby}">
+ <arg line="scripts/rjpp.rb "@{file}""/>
</exec>
</sequential>
</macrodef>
@@ -104,7 +105,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
<attribute name="definitions" />
<attribute name="file" />
<sequential>
- <exec executable="ruby">
+ <exec executable="${ruby}">
<arg line="-I templates scripts/macro.rb templates/@{definitions}.rb templates/@{file} src/@{file}" />
</exec>
</sequential>
@@ -170,7 +171,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
<static-class-to-float class="${pkgbase}.MatrixFunctions"/>
<static-class-to-float class="${pkgbase}.JavaBlas"/>
<static-class-to-float class="${pkgbase}.Singular"/>
- <!--<static-class-to-float class="${pkgbase}.Decompose"/>-->
</target>
<target name="preprocess" description="run the ruby preprocessor on necessary files">
<rjpp file="${src}/org/jblas/DoubleMatrix.java"/>
@@ -235,7 +235,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
<![CDATA[ <h1>jblas - Linear Algebra for Java (version ${version})</h1> ]]>
</doctitle>
<bottom>
- <![CDATA[ © 2008-2010 by Mikio L. Braun and contributors ]]>
+ <![CDATA[ © 2008-2013 by Mikio L. Braun and contributors ]]>
</bottom>
</javadoc>
</target>
@@ -252,7 +252,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
<!-- <test name="${pkgbase}.TestDoubleMatrix" />
<test name="${pkgbase}.TestEigen" />
<test name="${pkgbase}.ranges.IntervalRangeTest" />
- <test name="${pkgbase}.DecomposeTest" /> -->
+ <test name="${pkgbase}.TestDecompose" /> -->
<!-- <batchtest fork="yes" todir="${reports.tests}">
<fileset dir="${test}">
<include name="**/*Test*.java"/>
diff --git a/config/config.rb b/config/config.rb
index f8048fe..4ed9fbf 100644
--- a/config/config.rb
+++ b/config/config.rb
@@ -34,7 +34,7 @@
require 'config/path'
-module Config
+module JblasConfig
class ConfigError < Exception
attr_reader :message
@@ -285,7 +285,7 @@ module Config
end
if __FILE__ == $0
- include Config
+ include JblasConfig
configure :say_hello do
puts "Hello"
diff --git a/config/config_cc.rb b/config/config_cc.rb
index 6871149..fc4e687 100644
--- a/config/config_cc.rb
+++ b/config/config_cc.rb
@@ -36,7 +36,7 @@ require 'config/path'
require 'config/config'
require 'config/config_java'
-include Config
+include JblasConfig
include Path
# Set up flags for different environments.
@@ -44,12 +44,12 @@ configure :cc => 'CC'
desc 'Setting up gcc and flags'
configure 'CC', 'CFLAGS' => ['OS_NAME', 'OS_ARCH', 'JAVA_HOME'] do
- os_name = Config::CONFIG['OS_NAME']
- java_home = Config::CONFIG['JAVA_HOME']
+ os_name = JblasConfig::CONFIG['OS_NAME']
+ java_home = JblasConfig::CONFIG['JAVA_HOME']
case os_name
when 'Linux'
Path.check_cmd('gcc', 'make', 'ld')
- Config::CONFIG << <<EOS
+ JblasConfig::CONFIG << <<EOS
CC = gcc
CFLAGS = -fPIC
INCDIRS += -Iinclude -I#{java_home}/include -I#{java_home}/include/linux
@@ -60,7 +60,7 @@ LDFLAGS += -shared
EOS
when 'SunOS'
Path.check_cmd('gcc', 'make', 'ld')
- Config::CONFIG << <<EOS
+ JblasConfig::CONFIG << <<EOS
CC = gcc
CFLAGS = -fPIC
INCDIRS += -Iinclude -I#{java_home}/include -I#{java_home}/include/solaris
@@ -73,7 +73,7 @@ EOS
if w64build?
Path.check_cmd(W64_PREFIX + 'gcc', 'make', W64_PREFIX + 'ld')
Path.check_cmd('cygpath')
- Config::CONFIG << <<EOS
+ JblasConfig::CONFIG << <<EOS
CC = #{W64_PREFIX}gcc
CFLAGS = -ggdb -D__int64='long long'
INCDIRS += -I"#{dir java_home}/include" -I"#{dir java_home}/include/win32" -Iinclude
@@ -83,36 +83,37 @@ LIB =
RUBY = ruby
EOS
else
- Path.check_cmd('gcc', 'make', 'ld')
+ Path.check_cmd(W32_PREFIX + 'gcc', 'make', W32_PREFIX + 'ld')
Path.check_cmd('cygpath')
- Config::CONFIG << <<EOS
-CC = gcc
+ JblasConfig::CONFIG << <<EOS
+CC = #{W32_PREFIX}gcc
CFLAGS = -ggdb -D__int64='long long'
INCDIRS += -I"#{dir java_home}/include" -I"#{dir java_home}/include/win32" -Iinclude
-LDFLAGS += -mno-cygwin -shared -Wl,--add-stdcall-alias
+LDFLAGS += -shared -Wl,--add-stdcall-alias
SO = dll
LIB =
RUBY = ruby
EOS
end
when 'Mac\ OS\ X'
- Path.check_cmd('gcc-mp-4.3', 'make')
- Config::CONFIG << <<EOS
-CC = gcc-mp-4.3
-LD = gcc-mp-4.3
+ #Path.check_cmd('gcc-mp-4.3', 'make')
+ Path.check_cmd('gcc', 'make')
+ JblasConfig::CONFIG << <<EOS
+CC = gcc
+LD = gcc
CFLAGS = -fPIC
-INCDIRS += -Iinclude -I#{java_home}/include
+INCDIRS += -Iinclude -I#{java_home}/include -I#{java_home}/include/darwin
SO = jnilib
LIB = lib
RUBY = ruby
LDFLAGS += -shared
EOS
else
- Config.fail "Sorry, the OS #{os_name} is currently not supported"
+ JblasConfig.fail "Sorry, the OS #{os_name} is currently not supported"
end
- if %w(i386 x86 x86_64 amd64).include? Config::CONFIG['OS_ARCH']
- Config::CONFIG['CFLAGS'] << ' -DHAS_CPUID'
+ if %w(i386 x86 x86_64 amd64).include? JblasConfig::CONFIG['OS_ARCH']
+ JblasConfig::CONFIG['CFLAGS'] << ' -DHAS_CPUID'
end
ok(CONFIG['CC'])
@@ -120,4 +121,4 @@ end
if __FILE__ == $0
ConfigureTask.run :cc
-end
\ No newline at end of file
+end
diff --git a/config/config_fortran.rb b/config/config_fortran.rb
index 69b7053..8f1fb92 100644
--- a/config/config_fortran.rb
+++ b/config/config_fortran.rb
@@ -37,7 +37,7 @@ require 'config/config'
require 'config/config_cc'
require 'config/config_os_arch'
-include Config
+include JblasConfig
include Path
configure :fortran => ['F77', 'LD']
@@ -51,12 +51,18 @@ configure 'F77', 'LD' => ['OS_NAME', 'CC'] do
if CONFIG['OS_NAME'] == 'Mac\ OS\ X'
CONFIG['LD'] = CONFIG['CC']
- CONFIG['F77'] = 'gfortran-mp-4.3'
- CONFIG['CCC'] = 'c99'
- elsif CONFIG['OS_NAME'] == 'Windows' and CONFIG['OS_ARCH'] == 'amd64'
- CONFIG['LD'] = W64_PREFIX + 'gfortran'
- CONFIG['F77'] = W64_PREFIX + 'gfortran'
+ CONFIG['F77'] = 'gfortran'
CONFIG['CCC'] = 'c99'
+ elsif CONFIG['OS_NAME'] == 'Windows'
+ if CONFIG['OS_ARCH'] == 'amd64'
+ CONFIG['LD'] = W64_PREFIX + 'gfortran'
+ CONFIG['F77'] = W64_PREFIX + 'gfortran'
+ CONFIG['CCC'] = 'c99'
+ else
+ CONFIG['LD'] = W32_PREFIX + 'gfortran'
+ CONFIG['F77'] = W32_PREFIX + 'gfortran'
+ CONFIG['CCC'] = 'c99'
+ end
else
g77 = Path.where('g77')
gfortran = Path.where('gfortran')
diff --git a/config/config_java.rb b/config/config_java.rb
index c834f1b..4e309ec 100644
--- a/config/config_java.rb
+++ b/config/config_java.rb
@@ -37,7 +37,7 @@ require 'config/config'
require 'config/path'
require 'config/config_os_arch'
-include Config
+include JblasConfig
include Path
configure :java => ['FOUND_JAVA', 'JAVA_HOME']
@@ -57,9 +57,9 @@ configure 'JAVA_HOME' => ['FOUND_JAVA', 'OS_NAME'] do
else
java_home = dir(File.dirname(%x(java -cp config PrintProperty java.home)))
end
- if CONFIG['OS_NAME'] == 'Mac\ OS\ X'
- java_home = File.join(java_home, 'Home')
- end
+ #if CONFIG['OS_NAME'] == 'Mac\ OS\ X'
+ # java_home = File.join(java_home, 'Home')
+ #end
check_files java_home, ['include', 'jni.h'] do
CONFIG['JAVA_HOME'] = java_home #.escape
end
@@ -68,4 +68,4 @@ end
if __FILE__ == $0
ConfigureTask.run :java
-end
\ No newline at end of file
+end
diff --git a/config/config_lapack_sources.rb b/config/config_lapack_sources.rb
index 37757d0..cb56085 100644
--- a/config/config_lapack_sources.rb
+++ b/config/config_lapack_sources.rb
@@ -37,7 +37,7 @@ require 'config/path'
require 'config/opts'
require 'config/string_ext'
-include Config
+include JblasConfig
configure :lapack_sources => 'LAPACK_HOME'
def check_lapack_home(lapack_home)
@@ -54,7 +54,7 @@ configure 'LAPACK_HOME' do
rescue ConfigError => e
if $opts.defined? :download_lapack
puts "trying to download lapack (about 5M)"
- print "Looking for wget..."; check_cmd 'wget'; Config.ok
+ print "Looking for wget..."; check_cmd 'wget'; JblasConfig.ok
lapack_tgz = File.join('.', 'lapack-lite-3.1.1.tgz')
File.delete(lapack_tgz) if File.exist?(lapack_tgz)
system("wget http://www.netlib.org/lapack/lapack-lite-3.1.1.tgz")
@@ -79,5 +79,5 @@ end
if __FILE__ == $0
$opts = Opts.new(ARGV)
- Config.run :lapack_sources
+ JblasConfig.run :lapack_sources
end
\ No newline at end of file
diff --git a/config/config_libs.rb b/config/config_libs.rb
index 0ca80c4..2055a49 100644
--- a/config/config_libs.rb
+++ b/config/config_libs.rb
@@ -54,7 +54,7 @@ require 'config/opts'
require 'config/config_os_arch'
require 'config/config_fortran'
-include Config
+include JblasConfig
ATLAS_REQUIRED_SYMBOLS = [
'dsyev_', # eigenvalue function not yet included in ATLAS/LAPACK
@@ -65,7 +65,7 @@ ATLAS_REQUIRED_SYMBOLS = [
'ATL_caxpy'
]
-LAPACK_REQUIRED_SYMBOLS = [ 'dsyev_', 'daxpy_' ]
+LAPACK_REQUIRED_SYMBOLS = [ 'dsyev_', 'daxpy_', 'dgemm_' ]
ATLAS_LIBS = %w(lapack lapack_fortran lapack_atlas f77blas cblas atlas)
PT_ATLAS_LIBS = %w(lapack lapack_fortran lapack_atlas ptf77blas ptcblas atlas)
@@ -113,17 +113,22 @@ end
desc 'looking for libraries...'
configure 'LOADLIBES' => ['LINKAGE_TYPE', :libpath, 'F77', 'BUILD_TYPE', 'OS_ARCH'] do
- case CONFIG['BUILD_TYPE']
- when 'atlas'
- if $opts.defined? :ptatlas
- libs = PT_ATLAS_LIBS
- else
- libs = ATLAS_LIBS
- end
- syms = ATLAS_REQUIRED_SYMBOLS
- when 'lapack'
- libs = LAPACK_LIBS
+ if $opts.defined? :libs
+ libs = $opts[:libs].split(',')
syms = LAPACK_REQUIRED_SYMBOLS
+ else
+ case CONFIG['BUILD_TYPE']
+ when 'atlas'
+ if $opts.defined? :ptatlas
+ libs = PT_ATLAS_LIBS
+ else
+ libs = ATLAS_LIBS
+ end
+ syms = ATLAS_REQUIRED_SYMBOLS
+ when 'lapack'
+ libs = LAPACK_LIBS
+ syms = LAPACK_REQUIRED_SYMBOLS
+ end
end
result = LibHelpers.find_libs(CONFIG[:libpath], libs, syms)
@@ -142,28 +147,33 @@ configure 'LOADLIBES' => ['LINKAGE_TYPE', :libpath, 'F77', 'BUILD_TYPE', 'OS_ARC
CONFIG['LOADLIBES'] += result.keys.
sort {|x, y| libs.index(x) <=> libs.index(y)}.
map {|s| File.join(result[s], LibHelpers.libname(s)).escape }
- if CONFIG['F77'] == 'gfortran'
+ if CONFIG['F77'] =~ /gfortran$/
puts CONFIG['OS_ARCH']
if CONFIG['OS_NAME'] == 'Linux' and CONFIG['OS_ARCH'] == 'amd64'
- CONFIG['LOADLIBES'] += ['-lgfortran']
+ CONFIG['LOADLIBES'] += ['-lgfortran']
puts <<EOS
WARNING: on 64bit Linux, I cannot link the gfortran library into the shared library
because it's usually not compiled with -fPIC. This means that you need to
have libgfortran.so installed on your target system. Sorry for the inconvenience!
EOS
+ elsif CONFIG['OS_NAME'] == 'Mac\ OS\ X'
+ print "Looking for where libgfortran.a is... "
+ libgfortran_path = %x(gfortran -print-file-name=libgfortran.a).strip
+ puts "(#{libgfortran_path})"
+ CONFIG['LOADLIBES'] += [libgfortran_path]
else
CONFIG['LOADLIBES'] += ['-l:libgfortran.a']
end
end
- if CONFIG['OS_NAME'] == 'Mac\ OS\ X'
- CONFIG['LOADLIBES'] += ['/opt/local/lib/gcc43/libgfortran.a']
- end
+ #if CONFIG['OS_NAME'] == 'Mac\ OS\ X'
+ # CONFIG['LOADLIBES'] += ['/opt/local/lib/gcc43/libgfortran.a']
+ #end
end
ok
end
if __FILE__ == $0
$opts = Opts.new(ARGV)
- Config.run :libs
- Config::CONFIG.dump($stdout)
+ JblasConfig.run :libs
+ JblasConfig::CONFIG.dump($stdout)
end
diff --git a/config/config_make.rb b/config/config_make.rb
index b0b1c25..151a67e 100644
--- a/config/config_make.rb
+++ b/config/config_make.rb
@@ -35,7 +35,7 @@
require 'config/config'
require 'config/path'
-include Config
+include JblasConfig
include Path
configure :make => ['MAKE']
@@ -47,7 +47,7 @@ configure 'MAKE' do
CONFIG['MAKE'] = 'gmake'
else
if Path.where_with_output('make -v', /GNU Make/).nil?
- Config.fail('I need GNU make to run...')
+ JblasConfig.fail('I need GNU make to run...')
end
CONFIG['MAKE'] = 'make'
end
diff --git a/config/config_os_arch.rb b/config/config_os_arch.rb
index 1e9e2ec..0aa044c 100644
--- a/config/config_os_arch.rb
+++ b/config/config_os_arch.rb
@@ -40,7 +40,7 @@ require 'config/config'
require 'config/string_ext'
require 'config/config_java'
-include Config
+include JblasConfig
def detect_os
os_name = %x(java -cp config PrintProperty os.name).chomp
diff --git a/config/config_tools.rb b/config/config_tools.rb
index 8f87f8a..c23e1c4 100644
--- a/config/config_tools.rb
+++ b/config/config_tools.rb
@@ -37,7 +37,7 @@ require 'config/path'
require 'config/config_os_arch'
require 'config/windows'
-include Config
+include JblasConfig
include Path
configure :tools => ['FOUND_NM', 'FOUND_CYGPATH']
diff --git a/config/configure.rb b/config/configure.rb
index 5234eab..f45fb50 100644
--- a/config/configure.rb
+++ b/config/configure.rb
@@ -32,6 +32,8 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
## --- END LICENSE BLOCK ---
+$: << "."
+
require 'config/path'
require 'config/config'
require 'config/opts'
@@ -44,7 +46,7 @@ require 'config/config_make'
require 'config/config_lapack_sources'
require 'config/config_libs'
-include Config
+include JblasConfig
include Path
args = []
@@ -82,6 +84,7 @@ options summary:
(default for Windows!)
--ptatlas Link against multithreaded versions of ATLAS libraries
--arch-flavor=... Set architectural flavor (e.g. --arch-flavor=sse2)
+ --libs=lib1,lib2,... Override libraries to search in
EOS
configure :all => [:os_arch, :tools, :java, :cc, :fortran, :make, :lapack_sources, :libs]
diff --git a/config/lib_helpers.rb b/config/lib_helpers.rb
index cde1a7a..e8a16c1 100644
--- a/config/lib_helpers.rb
+++ b/config/lib_helpers.rb
@@ -40,11 +40,11 @@ module LibHelpers
module_function
def libname(name)
- case Config::CONFIG['LINKAGE_TYPE']
+ case JblasConfig::CONFIG['LINKAGE_TYPE']
when 'static'
'lib' + name + '.a'
when 'dynamic'
- case Config::CONFIG['OS_NAME']
+ case JblasConfig::CONFIG['OS_NAME']
when 'Linux'
'lib' + name + '.so'
when 'SunOS'
@@ -54,17 +54,17 @@ module LibHelpers
when 'Mac\ OS\ X'
'lib' + name + '.dylib'
else
- Config.fail "Sorry, OS '#{Config::CONFIG['OS_NAME']}' is not supported yet..."
+ JblasConfig.fail "Sorry, OS '#{JblasConfig::CONFIG['OS_NAME']}' is not supported yet..."
end
else
- raise "LINKAGE_TYPE should be either dynamic or static, but is #{Config::CONFIG['LINKAGE_TYPE']}"
+ raise "LINKAGE_TYPE should be either dynamic or static, but is #{JblasConfig::CONFIG['LINKAGE_TYPE']}"
end
end
# returns an array of the symbols defined in the library +fn+.
def libsyms(fn)
nmopt = File.extname(fn) == '.so' ? '-D' : ''
- %x(#{Config::CONFIG['NM']} -p #{nmopt} #{fn.escape}).grep(/ T _?([a-zA-Z0-9_]+)/) {|m| $1}
+ %x(#{JblasConfig::CONFIG['NM']} -p #{nmopt} #{fn.escape}).split("\n").grep(/ T _?([a-zA-Z0-9_]+)/) {|m| $1}
end
def locate_lib(libpath, name, symbol=nil)
@@ -73,10 +73,10 @@ module LibHelpers
end
if not p
- Config.fail("couldn't find library '#{name}' in\npath #{libpath.join ':'}")
+ JblasConfig.fail("couldn't find library '#{name}' in\npath #{libpath.join ':'}")
end
- Config.log "found library #{name} in #{p}"
+ JblasConfig.log "found library #{name} in #{p}"
return p
end
@@ -86,26 +86,26 @@ module LibHelpers
def locate_one_of_libs(libpath, names, symbol=nil)
p = nil
l = nil
- Config.log "Searching for one of #{names.join ', '} in #{libpath.join ':'}#{if symbol then ' having symbol ' + symbol.to_s end}"
+ JblasConfig.log "Searching for one of #{names.join ', '} in #{libpath.join ':'}#{if symbol then ' having symbol ' + symbol.to_s end}"
for name in names
- Config.log " Searching for #{libname(name)}"
+ JblasConfig.log " Searching for #{libname(name)}"
p = Path.where(libname(name), libpath) do |fn|
symbol.nil? or libsyms(fn).include? symbol
end
if p
l = name
- Config.log "Found at #{l} at #{p}"
+ JblasConfig.log "Found at #{l} at #{p}"
break
end
end
if not p
- Config.log "Haven't found any of #{names.join ', '}!"
- Config.fail("couldn't find library '#{name}' in\npath #{LIBPATH.join ':'}")
+ JblasConfig.log "Haven't found any of #{names.join ', '}!"
+ JblasConfig.fail("couldn't find library '#{name}' in\npath #{LIBPATH.join ':'}")
end
- Config.log "found library #{l} in #{p}"
+ JblasConfig.log "found library #{l} in #{p}"
return p, l
end
@@ -133,7 +133,7 @@ module LibHelpers
not_found_symbols = symbols.reject {|s| found_symbols.include? s }
unless not_found_symbols.empty?
- Config.fail "Could not locate libraries for the following symbols: #{not_found_symbols.join ', '}."
+ JblasConfig.fail "Could not locate libraries for the following symbols: #{not_found_symbols.join ', '}."
end
#found_symbols.each_pair {|k,v| printf "%20s: %s\n", k, v.inspect}
@@ -160,7 +160,7 @@ if __FILE__ == $0
libs = %w(atlas lapack blas f77blas cblas lapack_atlas)
- Config::CONFIG['BUILD_TYPE'] = 'static'
- Config::CONFIG['OS_NAME'] = 'Linux'
+ JblasConfig::CONFIG['BUILD_TYPE'] = 'static'
+ JblasConfig::CONFIG['OS_NAME'] = 'Linux'
p find_libs(paths, libs, symbols_needed)
end
\ No newline at end of file
diff --git a/config/path.rb b/config/path.rb
index 62695fd..00aaf2c 100644
--- a/config/path.rb
+++ b/config/path.rb
@@ -76,8 +76,8 @@ module Path
# Check whether a cmd could be found.
def check_cmd(*cmds)
cmds.each do |cmd|
- Config.log "Searching for command #{cmd}"
- Config.fail("coulnd't find command #{cmd}") unless Path.where cmd
+ JblasConfig.log "Searching for command #{cmd}"
+ JblasConfig.fail("coulnd't find command #{cmd}") unless Path.where cmd
end
yield self if block_given?
return
@@ -87,8 +87,8 @@ module Path
def check_files(path, *files)
files.each do |file|
file = File.join(path, *file)
- Config.log "Searching for file #{file}"
- Config.fail("couldn't find #{file}") unless File.exist? file
+ JblasConfig.log "Searching for file #{file}"
+ JblasConfig.fail("couldn't find #{file}") unless File.exist? file
end
yield if block_given?
return
@@ -96,7 +96,7 @@ module Path
# translate dir (mainly necessary for cygwin)
def dir(s)
- case Config::CONFIG['OS_NAME']
+ case JblasConfig::CONFIG['OS_NAME']
when 'Windows'
s = s.gsub(/\\/, '\\\\\\\\')
%x(cygpath -u '#{s}').chomp
diff --git a/config/windows.rb b/config/windows.rb
index 999e54f..3602114 100755
--- a/config/windows.rb
+++ b/config/windows.rb
@@ -35,7 +35,8 @@
require 'config/config'
W64_PREFIX = 'x86_64-w64-mingw32-'
+W32_PREFIX = 'i686-pc-mingw32-'
def w64build?
- Config::CONFIG['OS_NAME'] == 'Windows' and Config::CONFIG['OS_ARCH'] == 'amd64'
+ JblasConfig::CONFIG['OS_NAME'] == 'Windows' and JblasConfig::CONFIG['OS_ARCH'] == 'amd64'
end
diff --git a/fortranwrapper.dump b/fortranwrapper.dump
index 5444084..13d94b7 100644
Binary files a/fortranwrapper.dump and b/fortranwrapper.dump differ
diff --git a/pom.xml b/pom.xml
index 62c43da..81b86e3 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1,67 +1,137 @@
-<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
-  xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
-  <modelVersion>4.0.0</modelVersion>
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
-  <groupId>org.jblas</groupId>
-  <artifactId>jblas</artifactId>
-  <version>1.2.0</version>
-  <packaging>jar</packaging>
+  <groupId>org.jblas</groupId>
+  <artifactId>jblas</artifactId>
+  <version>1.2.3</version>
+  <packaging>jar</packaging>
-  <name>jblas</name>
-  <url>http://maven.apache.org</url>
+  <name>jblas</name>
+  <description>A fast linear algebra library for Java.</description>
+  <url>http://jblas.org/</url>
+  <licenses>
+    <license>
+      <name>BSD 3-clause style license</name>
+      <url>http://opensource.org/licenses/BSD-3-Clause</url>
+      <distribution>repo</distribution>
+    </license>
+  </licenses>
-  <properties>
-    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
-  </properties>
+  <scm>
+    <connection>scm:git:https://github.com/mikiobraun/jblas.git</connection>
+    <developerConnection>scm:git:https://github.com/mikiobraun/jblas.git</developerConnection>
+    <url>https://github.com/mikiobraun/jblas.git</url>
+  </scm>
-  <build>
-    <plugins>
-      <plugin>
-        <groupId>org.apache.maven.plugins</groupId>
-        <artifactId>maven-compiler-plugin</artifactId>
-        <configuration>
-          <source>1.5</source>
-          <target>1.5</target>
-        </configuration>
-      </plugin>
-      <plugin>
-        <groupId>org.apache.maven.plugins</groupId>
-        <artifactId>maven-jar-plugin</artifactId>
-        <configuration>
-          <archive>
-            <manifest>
-              <mainClass>org.jblas.benchmark.Main</mainClass>
-            </manifest>
-          </archive>
-        </configuration>
-      </plugin>
-      <plugin>
-        <artifactId>maven-antrun-plugin</artifactId>
-        <executions>
-          <execution>
-            <id>generate-float-sources</id>
-            <phase>generate-sources</phase>
-            <goals>
-              <goal>run</goal>
-            </goals>
-            <configuration>
-              <tasks>
-                <ant target="generate-float" />
-              </tasks>
-            </configuration>
-          </execution>
-        </executions>
-      </plugin>
-    </plugins>
-  </build>
+  <developers>
+    <developer>
+      <id>mikiobraun</id>
+      <name>Mikio L. Braun</name>
+      <email>mikiobraun at gmail.com</email>
+    </developer>
+  </developers>
+  <parent>
+    <groupId>org.sonatype.oss</groupId>
+    <artifactId>oss-parent</artifactId>
+    <version>7</version>
+  </parent>
-  <dependencies>
-    <dependency>
-      <groupId>junit</groupId>
-      <artifactId>junit</artifactId>
-      <version>3.8.1</version>
-      <scope>test</scope>
-    </dependency>
-  </dependencies>
+  <properties>
+    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+  </properties>
+
+  <dependencies>
+    <dependency>
+      <groupId>junit</groupId>
+      <artifactId>junit</artifactId>
+      <version>4.10</version>
+      <scope>test</scope>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <plugins>
+      <!-- the usual config of the maven compiler plugin -->
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-compiler-plugin</artifactId>
+        <version>2.3.2</version>
+        <configuration>
+          <source>1.5</source>
+          <target>1.5</target>
+        </configuration>
+      </plugin>
+
+      <!-- some config for the jar -->
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-jar-plugin</artifactId>
+        <version>2.3.1</version>
+        <configuration>
+          <archive>
+            <manifest>
+              <mainClass>org.jblas.benchmark.Main</mainClass>
+            </manifest>
+          </archive>
+        </configuration>
+      </plugin>
+
+      <!-- we need to call some ant targets to generate automatic sources (version pinned like the other plugins; since antrun 1.5 the ant tasks go in target, tasks is deprecated) -->
+      <plugin>
+        <artifactId>maven-antrun-plugin</artifactId>
+        <version>1.7</version>
+        <executions>
+          <execution>
+            <id>generate-float-sources</id>
+            <phase>generate-sources</phase>
+            <goals>
+              <goal>run</goal>
+            </goals>
+            <configuration>
+              <target>
+                <ant target="generate-float" />
+              </target>
+            </configuration>
+          </execution>
+        </executions>
+      </plugin>
+
+      <!-- javadoc jar -->
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-javadoc-plugin</artifactId>
+        <version>2.9</version>
+        <executions>
+          <execution>
+            <goals>
+              <goal>jar</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
+
+      <!-- sources.jar -->
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-source-plugin</artifactId>
+        <version>2.2.1</version>
+        <configuration>
+          <excludeResources>true</excludeResources>
+          <excludes>
+            <exclude>**/*.rjpp</exclude>
+            <exclude>*.html</exclude>
+            <exclude>*.textile</exclude>
+          </excludes>
+        </configuration>
+        <executions>
+          <execution>
+            <goals>
+              <goal>jar</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
+    </plugins>
+  </build>
 </project>
diff --git a/scripts/class_to_float.rb b/scripts/class_to_float.rb
index e37af2a..4c5415a 100644
--- a/scripts/class_to_float.rb
+++ b/scripts/class_to_float.rb
@@ -67,7 +67,25 @@ def translate(s)
s.gsub! /<Double>/, '<Float>'
s.gsub! /readDouble/, 'readFloat'
s.gsub! /java.lang.Double/, 'java.lang.Float'
- s
+
+ # go through lines. If a line starts with "//FLOAT//" replace the next line by the following line
+ discard_lines = 0
+ result = []
+ s.split("\n").each do |line|
+ if discard_lines > 0
+ discard_lines -= 1
+ else
+ i = line.index('//FLOAT//')
+ if i
+ discard_lines = 1
+ result <<= line[(i + "//FLOAT//".length)..-1]
+ else
+ result <<= line
+ end
+ end
+ end
+
+ result.join("\n")
end
if ARGV.size < 2
diff --git a/scripts/fortran/java.rb b/scripts/fortran/java.rb
index be7955a..8221f01 100644
--- a/scripts/fortran/java.rb
+++ b/scripts/fortran/java.rb
@@ -377,6 +377,8 @@ EOS
end
elsif javatype == 'char'
Java::CharArgument.new(self)
+ elsif javatype == 'String'
+ Java::StringArgument.new(self)
else
Java::GenericArgument.new(self)
end
@@ -712,5 +714,20 @@ EOS
code.fortran_args << ctype[1...-5] + " *"
end
end
+
+    class StringArgument < GenericArgument
+      def make_fortran_arg
+        code.fortran_args << 'char *'
+      end
+      # Fortran CHARACTER data is single-byte, so use the (modified) UTF-8 view; GetStringChars would hand over UTF-16 jchars.
+      def make_convert_arg
+        code.conversions << "  const char *#{name}Str = (*env)->GetStringUTFChars(env, #{name}, NULL);\n"
+        code.release_arrays << "  (*env)->ReleaseStringUTFChars(env, #{name}, #{name}Str);\n"
+      end
+      # NOTE(review): CHARACTER*(*) dummies also expect a trailing hidden length argument, which is not emitted here yet.
+      def make_call_arg
+        code.call_args << "(char *) #{name}Str"
+      end
+    end
end # module Java
end # module Fortran
diff --git a/scripts/fortran/parser.rb b/scripts/fortran/parser.rb
index ea24024..bf2f2e7 100644
--- a/scripts/fortran/parser.rb
+++ b/scripts/fortran/parser.rb
@@ -166,6 +166,7 @@ module Fortran
puts "#$1 -> #$2" if $debug
type = $1
args = $2.scan ArgumentParens
+ type = type.sub(/\(\s+/, "(").sub(/\s+\)/, ")")
args.each do |argname|
puts " #{argname} -> #{type}" if $debug
if argname =~ /([A-Z0-9]+)\ *\(.*\)/
diff --git a/scripts/fortran/types.rb b/scripts/fortran/types.rb
index e3167f2..f94ece6 100644
--- a/scripts/fortran/types.rb
+++ b/scripts/fortran/types.rb
@@ -53,7 +53,8 @@ module Fortran
'DOUBLE PRECISION' => 'REAL*8',
'INTEGER' => 'INTEGER*4',
'LOGICAL' => 'LOGICAL*4',
- 'REAL' => 'REAL*4' }
+ 'REAL' => 'REAL*4',
+ 'CHARACTER*(*)' => 'CHARACTER*N' }
# Map fortran types to a unique name (via DefaultTypes).
def self.standardize_type(name)
diff --git a/scripts/java-class.java b/scripts/java-class.java
index dd3950d..0be0fdd 100644
--- a/scripts/java-class.java
+++ b/scripts/java-class.java
@@ -74,21 +74,13 @@ import org.jblas.util.Logger;
public class <%= classname %> {
static {
- try {
- System.loadLibrary("jblas");
- } catch (UnsatisfiedLinkError e) {
- Logger.getLogger().config(
- "BLAS native library not found in path. Copying native library "
- + "from the archive. Consider installing the library somewhere "
- + "in the path (for Windows: PATH, for Linux: LD_LIBRARY_PATH).");
- new org.jblas.util.LibraryLoader().loadLibrary("jblas", true);
- }
- }
- private static int[] intDummy = new int[1];
- private static double[] doubleDummy = new double[1];
- private static float[] floatDummy = new float[1];
+ NativeBlasLibraryLoader.loadLibraryAndCheckErrors();
+ }
+
+ private static int[] intDummy = new int[1];
+ private static double[] doubleDummy = new double[1];
+ private static float[] floatDummy = new float[1];
-
<% for r in routines -%>
<%= generate_native_declaration r %>
<% end %>
diff --git a/scripts/java-impl.c b/scripts/java-impl.c
index 74b6ec4..fcf49c3 100644
--- a/scripts/java-impl.c
+++ b/scripts/java-impl.c
@@ -60,6 +60,8 @@ static jobject createObject(JNIEnv *env, const char *className, const char *sign
va_start(args, signature);
newObject = (*env)->NewObjectV(env, klass, init, args);
va_end(args);
+
+ return newObject;
}
<% if $complexcc == 'f2c' %>
diff --git a/scripts/rjpp.rb b/scripts/rjpp.rb
index 465c8cd..822da20 100644
--- a/scripts/rjpp.rb
+++ b/scripts/rjpp.rb
@@ -53,11 +53,15 @@
#
# So you can run the rjpp twice on a file and get the same result as
# running it once. This property is called idempotency in mathematics X-D
+#
+# New feature:
+#
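+# an expansion that evaluates to an Array is joined into one string, and one that evaluates to nil or false is left unexpanded.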
# print usage
if ARGV.length == 0
- puts "Usage: jrpp file"
+ puts "Usage: rjpp file"
end
def collect(*args)
@@ -79,10 +83,15 @@ file.gsub! /\/\/RJPP-BEGIN.*?\/\/RJPP-END[^\n]*\n/m, ''
# expand code
file.gsub! /\/\*\#(.*?)\#\*\//m do |s|
- expansion = eval($1).to_s
result = s
- unless expansion.empty?
- result << "\n//RJPP-BEGIN------------------------------------------------------------\n" + expansion + "//RJPP-END--------------------------------------------------------------"
+ expansion = eval($1)
+ if expansion
+ if Array === expansion
+ expansion = expansion.join
+ end
+ unless expansion.empty?
+ result << ("\n//RJPP-BEGIN------------------------------------------------------------\n" + expansion + "//RJPP-END--------------------------------------------------------------")
+ end
end
result
end
diff --git a/src/main/c/NativeBlas.c b/src/main/c/NativeBlas.c
index 71b814e..3a38322 100644
--- a/src/main/c/NativeBlas.c
+++ b/src/main/c/NativeBlas.c
@@ -60,6 +60,8 @@ static jobject createObject(JNIEnv *env, const char *className, const char *sign
va_start(args, signature);
newObject = (*env)->NewObjectV(env, klass, init, args);
va_end(args);
+
+ return newObject;
}
@@ -107,7 +109,7 @@ static void throwIllegalArgumentException(JNIEnv *env, const char *message)
/**********************************************************************/
static char *routine_names[] = {
- "CAXPY", "CCOPY", "CDOTC", "CDOTU", "CGEEV", "CGEMM", "CGEMV", "CGERC", "CGERU", "CGESVD", "CSCAL", "CSSCAL", "CSWAP", "DASUM", "DAXPY", "DCOPY", "DDOT", "DGEEV", "DGEMM", "DGEMV", "DGER", "DGESV", "DGESVD", "DGETRF", "DNRM2", "DPOSV", "DPOTRF", "DSCAL", "DSWAP", "DSYEV", "DSYEVD", "DSYEVR", "DSYEVX", "DSYGVD", "DSYSV", "DZASUM", "DZNRM2", "ICAMAX", "IDAMAX", "ISAMAX", "IZAMAX", "SASUM", "SAXPY", "SCASUM", "SCNRM2", "SCOPY", "SDOT", "SGEEV" [...]
+ "CAXPY", "CCOPY", "CDOTC", "CDOTU", "CGEEV", "CGEMM", "CGEMV", "CGERC", "CGERU", "CGESVD", "CSCAL", "CSSCAL", "CSWAP", "DASUM", "DAXPY", "DCOPY", "DDOT", "DGEEV", "DGELSD", "DGEMM", "DGEMV", "DGEQRF", "DGER", "DGESV", "DGESVD", "DGETRF", "DNRM2", "DORMQR", "DPOSV", "DPOTRF", "DSCAL", "DSWAP", "DSYEV", "DSYEVD", "DSYEVR", "DSYEVX", "DSYGVD", "DSYSV", "DZASUM", "DZNRM2", "ICAMAX", "IDAMAX", "ILAENV", "ISAMAX", "IZAMAX", "SASUM", "SAXPY", "SCA [...]
};
static char *routine_arguments[][21] = {
@@ -129,13 +131,16 @@ static char *routine_arguments[][21] = {
{ "N", "DX", "INCX", "DY", "INCY" },
{ "N", "DX", "INCX", "DY", "INCY" },
{ "JOBVL", "JOBVR", "N", "A", "LDA", "WR", "WI", "VL", "LDVL", "VR", "LDVR", "WORK", "LWORK", "INFO" },
+ { "M", "N", "NRHS", "A", "LDA", "B", "LDB", "S", "RCOND", "RANK", "WORK", "LWORK", "IWORK", "INFO" },
{ "TRANSA", "TRANSB", "M", "N", "K", "ALPHA", "A", "LDA", "B", "LDB", "BETA", "C", "LDC" },
{ "TRANS", "M", "N", "ALPHA", "A", "LDA", "X", "INCX", "BETA", "Y", "INCY" },
+ { "M", "N", "A", "LDA", "TAU", "WORK", "LWORK", "INFO" },
{ "M", "N", "ALPHA", "X", "INCX", "Y", "INCY", "A", "LDA" },
{ "N", "NRHS", "A", "LDA", "IPIV", "B", "LDB", "INFO" },
{ "JOBU", "JOBVT", "M", "N", "A", "LDA", "S", "U", "LDU", "VT", "LDVT", "WORK", "LWORK", "INFO" },
{ "M", "N", "A", "LDA", "IPIV", "INFO" },
{ "N", "X", "INCX" },
+ { "SIDE", "TRANS", "M", "N", "K", "A", "LDA", "TAU", "C", "LDC", "WORK", "LWORK", "INFO" },
{ "UPLO", "N", "NRHS", "A", "LDA", "B", "LDB", "INFO" },
{ "UPLO", "N", "A", "LDA", "INFO" },
{ "N", "DA", "DX", "INCX" },
@@ -150,6 +155,7 @@ static char *routine_arguments[][21] = {
{ "N", "X", "INCX" },
{ "N", "CX", "INCX" },
{ "N", "DX", "INCX" },
+ { "ISPEC", "NAME", "OPTS", "N1", "N2", "N3", "N4" },
{ "N", "SX", "INCX" },
{ "N", "ZX", "INCX" },
{ "N", "SX", "INCX" },
@@ -159,13 +165,16 @@ static char *routine_arguments[][21] = {
{ "N", "SX", "INCX", "SY", "INCY" },
{ "N", "SX", "INCX", "SY", "INCY" },
{ "JOBVL", "JOBVR", "N", "A", "LDA", "WR", "WI", "VL", "LDVL", "VR", "LDVR", "WORK", "LWORK", "INFO" },
+ { "M", "N", "NRHS", "A", "LDA", "B", "LDB", "S", "RCOND", "RANK", "WORK", "LWORK", "IWORK", "INFO" },
{ "TRANSA", "TRANSB", "M", "N", "K", "ALPHA", "A", "LDA", "B", "LDB", "BETA", "C", "LDC" },
{ "TRANS", "M", "N", "ALPHA", "A", "LDA", "X", "INCX", "BETA", "Y", "INCY" },
+ { "M", "N", "A", "LDA", "TAU", "WORK", "LWORK", "INFO" },
{ "M", "N", "ALPHA", "X", "INCX", "Y", "INCY", "A", "LDA" },
{ "N", "NRHS", "A", "LDA", "IPIV", "B", "LDB", "INFO" },
{ "JOBU", "JOBVT", "M", "N", "A", "LDA", "S", "U", "LDU", "VT", "LDVT", "WORK", "LWORK", "INFO" },
{ "M", "N", "A", "LDA", "IPIV", "INFO" },
{ "N", "X", "INCX" },
+ { "SIDE", "TRANS", "M", "N", "K", "A", "LDA", "TAU", "C", "LDC", "WORK", "LWORK", "INFO" },
{ "UPLO", "N", "NRHS", "A", "LDA", "B", "LDB", "INFO" },
{ "UPLO", "N", "A", "LDA", "INFO" },
{ "N", "SA", "SX", "INCX" },
@@ -4338,3 +4347,488 @@ JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_ssygvd(JNIEnv *env, jclass this
return info;
}
+JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_dgelsd(JNIEnv *env, jclass this, jint m, jint n, jint nrhs, jdoubleArray a, jint aIdx, jint lda, jdoubleArray b, jint bIdx, jint ldb, jdoubleArray s, jint sIdx, jdouble rcond, jintArray rank, jint rankIdx, jdoubleArray work, jint workIdx, jint lwork, jintArray iwork, jint iworkIdx)
+{
+  extern void dgelsd_(jint *, jint *, jint *, jdouble *, jint *, jdouble *, jint *, jdouble *, jdouble *, jint *, jdouble *, jint *, jint *, int *);
+  /* JNI bridge to LAPACK dgelsd_: pin each Java array and offset it by its *Idx; an argument that IsSameObject() with an earlier one reuses that pointer instead of being pinned twice. */
+  jdouble *aPtrBase = 0, *aPtr = 0;
+  if (a) {
+    aPtrBase = (*env)->GetDoubleArrayElements(env, a, NULL);
+    aPtr = aPtrBase + aIdx;
+  }
+  jint *iworkPtrBase = 0, *iworkPtr = 0;
+  if (iwork) {
+    iworkPtrBase = (*env)->GetIntArrayElements(env, iwork, NULL);
+    iworkPtr = iworkPtrBase + iworkIdx;
+  }
+  jdouble *bPtrBase = 0, *bPtr = 0;
+  if (b) {
+    if((*env)->IsSameObject(env, b, a) == JNI_TRUE)
+      bPtrBase = aPtrBase;
+    else
+      bPtrBase = (*env)->GetDoubleArrayElements(env, b, NULL);
+    bPtr = bPtrBase + bIdx;
+  }
+  jdouble *sPtrBase = 0, *sPtr = 0;
+  if (s) {
+    if((*env)->IsSameObject(env, s, a) == JNI_TRUE)
+      sPtrBase = aPtrBase;
+    else
+      if((*env)->IsSameObject(env, s, b) == JNI_TRUE)
+        sPtrBase = bPtrBase;
+      else
+        sPtrBase = (*env)->GetDoubleArrayElements(env, s, NULL);
+    sPtr = sPtrBase + sIdx;
+  }
+  jint *rankPtrBase = 0, *rankPtr = 0;
+  if (rank) {
+    if((*env)->IsSameObject(env, rank, iwork) == JNI_TRUE)
+      rankPtrBase = iworkPtrBase;
+    else
+      rankPtrBase = (*env)->GetIntArrayElements(env, rank, NULL);
+    rankPtr = rankPtrBase + rankIdx;
+  }
+  jdouble *workPtrBase = 0, *workPtr = 0;
+  if (work) {
+    if((*env)->IsSameObject(env, work, a) == JNI_TRUE)
+      workPtrBase = aPtrBase;
+    else
+      if((*env)->IsSameObject(env, work, b) == JNI_TRUE)
+        workPtrBase = bPtrBase;
+      else
+        if((*env)->IsSameObject(env, work, s) == JNI_TRUE)
+          workPtrBase = sPtrBase;
+        else
+          workPtrBase = (*env)->GetDoubleArrayElements(env, work, NULL);
+    workPtr = workPtrBase + workIdx;
+  }
+  int info;
+  /* NOTE(review): savedEnv is declared outside this view; presumably it lets a native error handler reach the JVM during the LAPACK call -- confirm. */
+  savedEnv = env;
+  dgelsd_(&m, &n, &nrhs, aPtr, &lda, bPtr, &ldb, sPtr, &rcond, rankPtr, workPtr, &lwork, iworkPtr, &info);
+  if(workPtrBase) {
+    (*env)->ReleaseDoubleArrayElements(env, work, workPtrBase, 0);
+    if (workPtrBase == aPtrBase)
+      aPtrBase = 0;
+    if (workPtrBase == bPtrBase)
+      bPtrBase = 0;
+    if (workPtrBase == sPtrBase)
+      sPtrBase = 0;
+    workPtrBase = 0;
+  }
+  if(rankPtrBase) {
+    (*env)->ReleaseIntArrayElements(env, rank, rankPtrBase, 0);
+    if (rankPtrBase == iworkPtrBase)
+      iworkPtrBase = 0;
+    rankPtrBase = 0;
+  }
+  if(sPtrBase) {
+    (*env)->ReleaseDoubleArrayElements(env, s, sPtrBase, 0);
+    if (sPtrBase == aPtrBase)
+      aPtrBase = 0;
+    if (sPtrBase == bPtrBase)
+      bPtrBase = 0;
+    sPtrBase = 0;
+  }
+  if(bPtrBase) {
+    (*env)->ReleaseDoubleArrayElements(env, b, bPtrBase, 0);
+    if (bPtrBase == aPtrBase)
+      aPtrBase = 0;
+    bPtrBase = 0;
+  }
+  if(iworkPtrBase) {
+    (*env)->ReleaseIntArrayElements(env, iwork, iworkPtrBase, JNI_ABORT);
+    iworkPtrBase = 0;
+  }
+  if(aPtrBase) {
+    (*env)->ReleaseDoubleArrayElements(env, a, aPtrBase, JNI_ABORT);
+    aPtrBase = 0;
+  }
+  /* Releases above run in reverse pinning order: mode 0 copies results back, JNI_ABORT (a, iwork) does not, and a pointer shared through aliasing is released once and then zeroed. info is dgelsd_'s INFO output. */
+  return info;
+}
+
+JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_sgelsd(JNIEnv *env, jclass this, jint m, jint n, jint nrhs, jfloatArray a, jint aIdx, jint lda, jfloatArray b, jint bIdx, jint ldb, jfloatArray s, jint sIdx, jfloat rcond, jintArray rank, jint rankIdx, jfloatArray work, jint workIdx, jint lwork, jintArray iwork, jint iworkIdx)
+{
+  extern void sgelsd_(jint *, jint *, jint *, jfloat *, jint *, jfloat *, jint *, jfloat *, jfloat *, jint *, jfloat *, jint *, jint *, int *);
+  /* JNI bridge to LAPACK sgelsd_: pin each Java array and offset it by its *Idx; an argument that IsSameObject() with an earlier one reuses that pointer instead of being pinned twice. */
+  jfloat *aPtrBase = 0, *aPtr = 0;
+  if (a) {
+    aPtrBase = (*env)->GetFloatArrayElements(env, a, NULL);
+    aPtr = aPtrBase + aIdx;
+  }
+  jint *iworkPtrBase = 0, *iworkPtr = 0;
+  if (iwork) {
+    iworkPtrBase = (*env)->GetIntArrayElements(env, iwork, NULL);
+    iworkPtr = iworkPtrBase + iworkIdx;
+  }
+  jfloat *bPtrBase = 0, *bPtr = 0;
+  if (b) {
+    if((*env)->IsSameObject(env, b, a) == JNI_TRUE)
+      bPtrBase = aPtrBase;
+    else
+      bPtrBase = (*env)->GetFloatArrayElements(env, b, NULL);
+    bPtr = bPtrBase + bIdx;
+  }
+  jfloat *sPtrBase = 0, *sPtr = 0;
+  if (s) {
+    if((*env)->IsSameObject(env, s, a) == JNI_TRUE)
+      sPtrBase = aPtrBase;
+    else
+      if((*env)->IsSameObject(env, s, b) == JNI_TRUE)
+        sPtrBase = bPtrBase;
+      else
+        sPtrBase = (*env)->GetFloatArrayElements(env, s, NULL);
+    sPtr = sPtrBase + sIdx;
+  }
+  jint *rankPtrBase = 0, *rankPtr = 0;
+  if (rank) {
+    if((*env)->IsSameObject(env, rank, iwork) == JNI_TRUE)
+      rankPtrBase = iworkPtrBase;
+    else
+      rankPtrBase = (*env)->GetIntArrayElements(env, rank, NULL);
+    rankPtr = rankPtrBase + rankIdx;
+  }
+  jfloat *workPtrBase = 0, *workPtr = 0;
+  if (work) {
+    if((*env)->IsSameObject(env, work, a) == JNI_TRUE)
+      workPtrBase = aPtrBase;
+    else
+      if((*env)->IsSameObject(env, work, b) == JNI_TRUE)
+        workPtrBase = bPtrBase;
+      else
+        if((*env)->IsSameObject(env, work, s) == JNI_TRUE)
+          workPtrBase = sPtrBase;
+        else
+          workPtrBase = (*env)->GetFloatArrayElements(env, work, NULL);
+    workPtr = workPtrBase + workIdx;
+  }
+  int info;
+  /* NOTE(review): savedEnv is declared outside this view; presumably it lets a native error handler reach the JVM during the LAPACK call -- confirm. */
+  savedEnv = env;
+  sgelsd_(&m, &n, &nrhs, aPtr, &lda, bPtr, &ldb, sPtr, &rcond, rankPtr, workPtr, &lwork, iworkPtr, &info);
+  if(workPtrBase) {
+    (*env)->ReleaseFloatArrayElements(env, work, workPtrBase, 0);
+    if (workPtrBase == aPtrBase)
+      aPtrBase = 0;
+    if (workPtrBase == bPtrBase)
+      bPtrBase = 0;
+    if (workPtrBase == sPtrBase)
+      sPtrBase = 0;
+    workPtrBase = 0;
+  }
+  if(rankPtrBase) {
+    (*env)->ReleaseIntArrayElements(env, rank, rankPtrBase, 0);
+    if (rankPtrBase == iworkPtrBase)
+      iworkPtrBase = 0;
+    rankPtrBase = 0;
+  }
+  if(sPtrBase) {
+    (*env)->ReleaseFloatArrayElements(env, s, sPtrBase, 0);
+    if (sPtrBase == aPtrBase)
+      aPtrBase = 0;
+    if (sPtrBase == bPtrBase)
+      bPtrBase = 0;
+    sPtrBase = 0;
+  }
+  if(bPtrBase) {
+    (*env)->ReleaseFloatArrayElements(env, b, bPtrBase, 0);
+    if (bPtrBase == aPtrBase)
+      aPtrBase = 0;
+    bPtrBase = 0;
+  }
+  if(iworkPtrBase) {
+    (*env)->ReleaseIntArrayElements(env, iwork, iworkPtrBase, JNI_ABORT);
+    iworkPtrBase = 0;
+  }
+  if(aPtrBase) {
+    (*env)->ReleaseFloatArrayElements(env, a, aPtrBase, JNI_ABORT);
+    aPtrBase = 0;
+  }
+  /* Releases above run in reverse pinning order: mode 0 copies results back, JNI_ABORT (a, iwork) does not, and a pointer shared through aliasing is released once and then zeroed. info is sgelsd_'s INFO output. */
+  return info;
+}
+
+JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_ilaenv(JNIEnv *env, jclass this, jint ispec, jstring name, jstring opts, jint n1, jint n2, jint n3, jint n4)
+{
+  /* NAME and OPTS are Fortran CHARACTER*(*): pass single-byte (modified UTF-8, i.e. ASCII for LAPACK names) text plus the hidden trailing lengths. */
+  extern jint ilaenv_(jint *, char *, char *, jint *, jint *, jint *, jint *, size_t, size_t);
+  const char *nameStr = (*env)->GetStringUTFChars(env, name, NULL);
+  const char *optsStr = nameStr ? (*env)->GetStringUTFChars(env, opts, NULL) : NULL;
+  jint retval = 0;
+  savedEnv = env;
+  if (nameStr && optsStr) /* a NULL result means an OutOfMemoryError is already pending; the 0 return is then ignored by Java */
+    retval = ilaenv_(&ispec, (char *) nameStr, (char *) optsStr, &n1, &n2, &n3, &n4, (size_t) (*env)->GetStringUTFLength(env, name), (size_t) (*env)->GetStringUTFLength(env, opts));
+  if (optsStr) (*env)->ReleaseStringUTFChars(env, opts, optsStr);
+  if (nameStr) (*env)->ReleaseStringUTFChars(env, name, nameStr);
+  return retval;
+}
+
+JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_dgeqrf(JNIEnv *env, jclass this, jint m, jint n, jdoubleArray a, jint aIdx, jint lda, jdoubleArray tau, jint tauIdx, jdoubleArray work, jint workIdx, jint lwork)
+{
+  extern void dgeqrf_(jint *, jint *, jdouble *, jint *, jdouble *, jdouble *, jint *, int *);
+  /* JNI bridge to LAPACK dgeqrf_: pin a, tau and work, offset each by its *Idx, and reuse the pointer of an earlier argument that IsSameObject() matches. */
+  jdouble *aPtrBase = 0, *aPtr = 0;
+  if (a) {
+    aPtrBase = (*env)->GetDoubleArrayElements(env, a, NULL);
+    aPtr = aPtrBase + aIdx;
+  }
+  jdouble *tauPtrBase = 0, *tauPtr = 0;
+  if (tau) {
+    if((*env)->IsSameObject(env, tau, a) == JNI_TRUE)
+      tauPtrBase = aPtrBase;
+    else
+      tauPtrBase = (*env)->GetDoubleArrayElements(env, tau, NULL);
+    tauPtr = tauPtrBase + tauIdx;
+  }
+  jdouble *workPtrBase = 0, *workPtr = 0;
+  if (work) {
+    if((*env)->IsSameObject(env, work, a) == JNI_TRUE)
+      workPtrBase = aPtrBase;
+    else
+      if((*env)->IsSameObject(env, work, tau) == JNI_TRUE)
+        workPtrBase = tauPtrBase;
+      else
+        workPtrBase = (*env)->GetDoubleArrayElements(env, work, NULL);
+    workPtr = workPtrBase + workIdx;
+  }
+  int info;
+  /* NOTE(review): savedEnv is declared outside this view; presumably it lets a native error handler reach the JVM during the LAPACK call -- confirm. */
+  savedEnv = env;
+  dgeqrf_(&m, &n, aPtr, &lda, tauPtr, workPtr, &lwork, &info);
+  if(workPtrBase) {
+    (*env)->ReleaseDoubleArrayElements(env, work, workPtrBase, 0);
+    if (workPtrBase == aPtrBase)
+      aPtrBase = 0;
+    if (workPtrBase == tauPtrBase)
+      tauPtrBase = 0;
+    workPtrBase = 0;
+  }
+  if(tauPtrBase) {
+    (*env)->ReleaseDoubleArrayElements(env, tau, tauPtrBase, 0);
+    if (tauPtrBase == aPtrBase)
+      aPtrBase = 0;
+    tauPtrBase = 0;
+  }
+  if(aPtrBase) {
+    (*env)->ReleaseDoubleArrayElements(env, a, aPtrBase, 0);
+    aPtrBase = 0;
+  }
+  /* Releases above run in reverse pinning order, all with mode 0 (results copied back); a pointer shared through aliasing is released once and then zeroed. info is dgeqrf_'s INFO output. */
+  return info;
+}
+
+JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_sgeqrf(JNIEnv *env, jclass this, jint m, jint n, jfloatArray a, jint aIdx, jint lda, jfloatArray tau, jint tauIdx, jfloatArray work, jint workIdx, jint lwork)
+{
+  extern void sgeqrf_(jint *, jint *, jfloat *, jint *, jfloat *, jfloat *, jint *, int *);
+  /* JNI bridge to LAPACK sgeqrf_: pin a, tau and work, offset each by its *Idx, and reuse the pointer of an earlier argument that IsSameObject() matches. */
+  jfloat *aPtrBase = 0, *aPtr = 0;
+  if (a) {
+    aPtrBase = (*env)->GetFloatArrayElements(env, a, NULL);
+    aPtr = aPtrBase + aIdx;
+  }
+  jfloat *tauPtrBase = 0, *tauPtr = 0;
+  if (tau) {
+    if((*env)->IsSameObject(env, tau, a) == JNI_TRUE)
+      tauPtrBase = aPtrBase;
+    else
+      tauPtrBase = (*env)->GetFloatArrayElements(env, tau, NULL);
+    tauPtr = tauPtrBase + tauIdx;
+  }
+  jfloat *workPtrBase = 0, *workPtr = 0;
+  if (work) {
+    if((*env)->IsSameObject(env, work, a) == JNI_TRUE)
+      workPtrBase = aPtrBase;
+    else
+      if((*env)->IsSameObject(env, work, tau) == JNI_TRUE)
+        workPtrBase = tauPtrBase;
+      else
+        workPtrBase = (*env)->GetFloatArrayElements(env, work, NULL);
+    workPtr = workPtrBase + workIdx;
+  }
+  int info;
+  /* NOTE(review): savedEnv is declared outside this view; presumably it lets a native error handler reach the JVM during the LAPACK call -- confirm. */
+  savedEnv = env;
+  sgeqrf_(&m, &n, aPtr, &lda, tauPtr, workPtr, &lwork, &info);
+  if(workPtrBase) {
+    (*env)->ReleaseFloatArrayElements(env, work, workPtrBase, 0);
+    if (workPtrBase == aPtrBase)
+      aPtrBase = 0;
+    if (workPtrBase == tauPtrBase)
+      tauPtrBase = 0;
+    workPtrBase = 0;
+  }
+  if(tauPtrBase) {
+    (*env)->ReleaseFloatArrayElements(env, tau, tauPtrBase, 0);
+    if (tauPtrBase == aPtrBase)
+      aPtrBase = 0;
+    tauPtrBase = 0;
+  }
+  if(aPtrBase) {
+    (*env)->ReleaseFloatArrayElements(env, a, aPtrBase, 0);
+    aPtrBase = 0;
+  }
+  /* Releases above run in reverse pinning order, all with mode 0 (results copied back); a pointer shared through aliasing is released once and then zeroed. info is sgeqrf_'s INFO output. */
+  return info;
+}
+
+JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_dormqr(JNIEnv *env, jclass this, jchar side, jchar trans, jint m, jint n, jint k, jdoubleArray a, jint aIdx, jint lda, jdoubleArray tau, jint tauIdx, jdoubleArray c, jint cIdx, jint ldc, jdoubleArray work, jint workIdx, jint lwork)
+{
+  extern void dormqr_(char *, char *, jint *, jint *, jint *, jdouble *, jint *, jdouble *, jdouble *, jint *, jdouble *, jint *, int *);
+  /* JNI bridge to LAPACK dormqr_: narrow side/trans from jchar to char, then pin a, tau, c and work, offset each by its *Idx, and reuse the pointer of an earlier argument that IsSameObject() matches. */
+  char sideChr = (char) side;
+  char transChr = (char) trans;
+  jdouble *aPtrBase = 0, *aPtr = 0;
+  if (a) {
+    aPtrBase = (*env)->GetDoubleArrayElements(env, a, NULL);
+    aPtr = aPtrBase + aIdx;
+  }
+  jdouble *tauPtrBase = 0, *tauPtr = 0;
+  if (tau) {
+    if((*env)->IsSameObject(env, tau, a) == JNI_TRUE)
+      tauPtrBase = aPtrBase;
+    else
+      tauPtrBase = (*env)->GetDoubleArrayElements(env, tau, NULL);
+    tauPtr = tauPtrBase + tauIdx;
+  }
+  jdouble *cPtrBase = 0, *cPtr = 0;
+  if (c) {
+    if((*env)->IsSameObject(env, c, a) == JNI_TRUE)
+      cPtrBase = aPtrBase;
+    else
+      if((*env)->IsSameObject(env, c, tau) == JNI_TRUE)
+        cPtrBase = tauPtrBase;
+      else
+        cPtrBase = (*env)->GetDoubleArrayElements(env, c, NULL);
+    cPtr = cPtrBase + cIdx;
+  }
+  jdouble *workPtrBase = 0, *workPtr = 0;
+  if (work) {
+    if((*env)->IsSameObject(env, work, a) == JNI_TRUE)
+      workPtrBase = aPtrBase;
+    else
+      if((*env)->IsSameObject(env, work, tau) == JNI_TRUE)
+        workPtrBase = tauPtrBase;
+      else
+        if((*env)->IsSameObject(env, work, c) == JNI_TRUE)
+          workPtrBase = cPtrBase;
+        else
+          workPtrBase = (*env)->GetDoubleArrayElements(env, work, NULL);
+    workPtr = workPtrBase + workIdx;
+  }
+  int info;
+  /* NOTE(review): savedEnv is declared outside this view; presumably it lets a native error handler reach the JVM during the LAPACK call -- confirm. */
+  savedEnv = env;
+  dormqr_(&sideChr, &transChr, &m, &n, &k, aPtr, &lda, tauPtr, cPtr, &ldc, workPtr, &lwork, &info);
+  if(workPtrBase) {
+    (*env)->ReleaseDoubleArrayElements(env, work, workPtrBase, 0);
+    if (workPtrBase == aPtrBase)
+      aPtrBase = 0;
+    if (workPtrBase == tauPtrBase)
+      tauPtrBase = 0;
+    if (workPtrBase == cPtrBase)
+      cPtrBase = 0;
+    workPtrBase = 0;
+  }
+  if(cPtrBase) {
+    (*env)->ReleaseDoubleArrayElements(env, c, cPtrBase, 0);
+    if (cPtrBase == aPtrBase)
+      aPtrBase = 0;
+    if (cPtrBase == tauPtrBase)
+      tauPtrBase = 0;
+    cPtrBase = 0;
+  }
+  if(tauPtrBase) {
+    (*env)->ReleaseDoubleArrayElements(env, tau, tauPtrBase, JNI_ABORT);
+    if (tauPtrBase == aPtrBase)
+      aPtrBase = 0;
+    tauPtrBase = 0;
+  }
+  if(aPtrBase) {
+    (*env)->ReleaseDoubleArrayElements(env, a, aPtrBase, JNI_ABORT);
+    aPtrBase = 0;
+  }
+  /* Releases above run in reverse pinning order: work and c with mode 0 (copied back), tau and a with JNI_ABORT (not copied back); a pointer shared through aliasing is released once and then zeroed. info is dormqr_'s INFO output. */
+  return info;
+}
+
+JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_sormqr(JNIEnv *env, jclass this, jchar side, jchar trans, jint m, jint n, jint k, jfloatArray a, jint aIdx, jint lda, jfloatArray tau, jint tauIdx, jfloatArray c, jint cIdx, jint ldc, jfloatArray work, jint workIdx, jint lwork)
+{
+  extern void sormqr_(char *, char *, jint *, jint *, jint *, jfloat *, jint *, jfloat *, jfloat *, jint *, jfloat *, jint *, int *);
+  /* JNI bridge to LAPACK sormqr_: narrow side/trans from jchar to char, then pin a, tau, c and work, offset each by its *Idx, and reuse the pointer of an earlier argument that IsSameObject() matches. */
+  char sideChr = (char) side;
+  char transChr = (char) trans;
+  jfloat *aPtrBase = 0, *aPtr = 0;
+  if (a) {
+    aPtrBase = (*env)->GetFloatArrayElements(env, a, NULL);
+    aPtr = aPtrBase + aIdx;
+  }
+  jfloat *tauPtrBase = 0, *tauPtr = 0;
+  if (tau) {
+    if((*env)->IsSameObject(env, tau, a) == JNI_TRUE)
+      tauPtrBase = aPtrBase;
+    else
+      tauPtrBase = (*env)->GetFloatArrayElements(env, tau, NULL);
+    tauPtr = tauPtrBase + tauIdx;
+  }
+  jfloat *cPtrBase = 0, *cPtr = 0;
+  if (c) {
+    if((*env)->IsSameObject(env, c, a) == JNI_TRUE)
+      cPtrBase = aPtrBase;
+    else
+      if((*env)->IsSameObject(env, c, tau) == JNI_TRUE)
+        cPtrBase = tauPtrBase;
+      else
+        cPtrBase = (*env)->GetFloatArrayElements(env, c, NULL);
+    cPtr = cPtrBase + cIdx;
+  }
+  jfloat *workPtrBase = 0, *workPtr = 0;
+  if (work) {
+    if((*env)->IsSameObject(env, work, a) == JNI_TRUE)
+      workPtrBase = aPtrBase;
+    else
+      if((*env)->IsSameObject(env, work, tau) == JNI_TRUE)
+        workPtrBase = tauPtrBase;
+      else
+        if((*env)->IsSameObject(env, work, c) == JNI_TRUE)
+          workPtrBase = cPtrBase;
+        else
+          workPtrBase = (*env)->GetFloatArrayElements(env, work, NULL);
+    workPtr = workPtrBase + workIdx;
+  }
+  int info;
+  /* NOTE(review): savedEnv is declared outside this view; presumably it lets a native error handler reach the JVM during the LAPACK call -- confirm. */
+  savedEnv = env;
+  sormqr_(&sideChr, &transChr, &m, &n, &k, aPtr, &lda, tauPtr, cPtr, &ldc, workPtr, &lwork, &info);
+  if(workPtrBase) {
+    (*env)->ReleaseFloatArrayElements(env, work, workPtrBase, 0);
+    if (workPtrBase == aPtrBase)
+      aPtrBase = 0;
+    if (workPtrBase == tauPtrBase)
+      tauPtrBase = 0;
+    if (workPtrBase == cPtrBase)
+      cPtrBase = 0;
+    workPtrBase = 0;
+  }
+  if(cPtrBase) {
+    (*env)->ReleaseFloatArrayElements(env, c, cPtrBase, 0);
+    if (cPtrBase == aPtrBase)
+      aPtrBase = 0;
+    if (cPtrBase == tauPtrBase)
+      tauPtrBase = 0;
+    cPtrBase = 0;
+  }
+  if(tauPtrBase) {
+    (*env)->ReleaseFloatArrayElements(env, tau, tauPtrBase, JNI_ABORT);
+    if (tauPtrBase == aPtrBase)
+      aPtrBase = 0;
+    tauPtrBase = 0;
+  }
+  if(aPtrBase) {
+    (*env)->ReleaseFloatArrayElements(env, a, aPtrBase, JNI_ABORT);
+    aPtrBase = 0;
+  }
+  /* Releases above run in reverse pinning order: work and c with mode 0 (copied back), tau and a with JNI_ABORT (not copied back); a pointer shared through aliasing is released once and then zeroed. info is sormqr_'s INFO output. */
+  return info;
+}
+
diff --git a/src/main/c/jblas_arch_flavor.c b/src/main/c/jblas_arch_flavor.c
index 3a51839..aa85196 100644
--- a/src/main/c/jblas_arch_flavor.c
+++ b/src/main/c/jblas_arch_flavor.c
@@ -2,6 +2,8 @@
/* detecting sse level */
/**********************************************************************/
+#include <stdio.h>
+#include <assert.h>
#include "org_jblas_util_ArchFlavor.h"
/*
diff --git a/src/main/c/org_jblas_NativeBlas.h b/src/main/c/org_jblas_NativeBlas.h
index 011a474..92333b5 100644
--- a/src/main/c/org_jblas_NativeBlas.h
+++ b/src/main/c/org_jblas_NativeBlas.h
@@ -7,9 +7,6 @@
#ifdef __cplusplus
extern "C" {
#endif
-/* Inaccessible static: intDummy */
-/* Inaccessible static: doubleDummy */
-/* Inaccessible static: floatDummy */
/*
* Class: org_jblas_NativeBlas
* Method: ccopy
@@ -634,6 +631,62 @@ JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_dsygvd
JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_ssygvd
(JNIEnv *, jclass, jint, jchar, jchar, jint, jfloatArray, jint, jint, jfloatArray, jint, jint, jfloatArray, jint, jfloatArray, jint, jint, jintArray, jint, jint);
+/*
+ * Class: org_jblas_NativeBlas
+ * Method: dgelsd
+ * Signature: (III[DII[DII[DID[II[DII[II)I
+ */
+JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_dgelsd
+ (JNIEnv *, jclass, jint, jint, jint, jdoubleArray, jint, jint, jdoubleArray, jint, jint, jdoubleArray, jint, jdouble, jintArray, jint, jdoubleArray, jint, jint, jintArray, jint);
+
+/*
+ * Class: org_jblas_NativeBlas
+ * Method: sgelsd
+ * Signature: (III[FII[FII[FIF[II[FII[II)I
+ */
+JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_sgelsd
+ (JNIEnv *, jclass, jint, jint, jint, jfloatArray, jint, jint, jfloatArray, jint, jint, jfloatArray, jint, jfloat, jintArray, jint, jfloatArray, jint, jint, jintArray, jint);
+
+/*
+ * Class: org_jblas_NativeBlas
+ * Method: ilaenv
+ * Signature: (ILjava/lang/String;Ljava/lang/String;IIII)I
+ */
+JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_ilaenv
+ (JNIEnv *, jclass, jint, jstring, jstring, jint, jint, jint, jint);
+
+/*
+ * Class: org_jblas_NativeBlas
+ * Method: dgeqrf
+ * Signature: (II[DII[DI[DII)I
+ */
+JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_dgeqrf
+ (JNIEnv *, jclass, jint, jint, jdoubleArray, jint, jint, jdoubleArray, jint, jdoubleArray, jint, jint);
+
+/*
+ * Class: org_jblas_NativeBlas
+ * Method: sgeqrf
+ * Signature: (II[FII[FI[FII)I
+ */
+JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_sgeqrf
+ (JNIEnv *, jclass, jint, jint, jfloatArray, jint, jint, jfloatArray, jint, jfloatArray, jint, jint);
+
+/*
+ * Class: org_jblas_NativeBlas
+ * Method: dormqr
+ * Signature: (CCIII[DII[DI[DII[DII)I
+ */
+JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_dormqr
+ (JNIEnv *, jclass, jchar, jchar, jint, jint, jint, jdoubleArray, jint, jint, jdoubleArray, jint, jdoubleArray, jint, jint, jdoubleArray, jint, jint);
+
+/*
+ * Class: org_jblas_NativeBlas
+ * Method: sormqr
+ * Signature: (CCIII[FII[FI[FII[FII)I
+ */
+JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_sormqr
+ (JNIEnv *, jclass, jchar, jchar, jint, jint, jint, jfloatArray, jint, jint, jfloatArray, jint, jfloatArray, jint, jint, jfloatArray, jint, jint);
+
#ifdef __cplusplus
}
#endif
diff --git a/src/main/c/org_jblas_util_ArchFlavor.h b/src/main/c/org_jblas_util_ArchFlavor.h
index 3b223b7..11b95d0 100644
--- a/src/main/c/org_jblas_util_ArchFlavor.h
+++ b/src/main/c/org_jblas_util_ArchFlavor.h
@@ -7,7 +7,6 @@
#ifdef __cplusplus
extern "C" {
#endif
-/* Inaccessible static: fixedFlavor */
#undef org_jblas_util_ArchFlavor_SSE
#define org_jblas_util_ArchFlavor_SSE 1L
#undef org_jblas_util_ArchFlavor_SSE2
diff --git a/src/main/java/org/jblas/ComplexDoubleMatrix.java b/src/main/java/org/jblas/ComplexDoubleMatrix.java
index 1f8b345..e054b59 100644
--- a/src/main/java/org/jblas/ComplexDoubleMatrix.java
+++ b/src/main/java/org/jblas/ComplexDoubleMatrix.java
@@ -57,18 +57,18 @@ public class ComplexDoubleMatrix {
*
**************************************************************************/
- /** Create a new matrix with <i>newRows</i> rows, <i>newColumns</i> columns
- * using <i>newData></i> as the data. The length of the data is not checked!
+ /**
+ * Create a new matrix with <i>newRows</i> rows, <i>newColumns</i> columns
+ * using <i>newData</i> as the data; throws IllegalArgumentException unless newData.length == 2 * newRows * newColumns.
*/
public ComplexDoubleMatrix(int newRows, int newColumns, double... newData) {
rows = newRows;
columns = newColumns;
length = rows * columns;
- if (newData.length != 2 * newRows * newColumns)
+ if (newData.length != 2 * newRows * newColumns)
throw new IllegalArgumentException(
"Passed data must match matrix dimensions.");
-
data = newData;
}
@@ -110,25 +110,25 @@ public class ComplexDoubleMatrix {
}
- /** Construct a complex matrix from a real matrix. */
- public ComplexDoubleMatrix(DoubleMatrix m) {
- this(m.rows, m.columns);
-
- NativeBlas.dcopy(m.length, m.data, 0, 1, data, 0, 2);
- }
-
- /** Construct a complex matrix from separate real and imaginary parts. Either
- * part can be set to null in which case it will be ignored.
- */
- public ComplexDoubleMatrix(DoubleMatrix real, DoubleMatrix imag) {
- this(real.rows, real.columns);
- real.assertSameSize(imag);
-
- if (real != null)
- NativeBlas.dcopy(length, real.data, 0, 1, data, 0, 2);
- if (imag != null)
- NativeBlas.dcopy(length, imag.data, 0, 1, data, 1, 2);
- }
+ /** Construct a complex matrix from a real matrix. */
+ public ComplexDoubleMatrix(DoubleMatrix m) {
+ this(m.rows, m.columns);
+
+ NativeBlas.dcopy(m.length, m.data, 0, 1, data, 0, 2);
+ }
+
+ /** Construct a complex matrix from separate real and imaginary parts. Either
+ * part (but not both) can be set to null in which case it will be ignored.
+ */
+ public ComplexDoubleMatrix(DoubleMatrix real, DoubleMatrix imag) {
+ this(real != null ? real.rows : imag.rows, real != null ? real.columns : imag.columns);
+ if (real != null && imag != null) real.assertSameSize(imag); // size check only makes sense with both parts
+
+ if (real != null)
+ NativeBlas.dcopy(length, real.data, 0, 1, data, 0, 2);
+ if (imag != null)
+ NativeBlas.dcopy(length, imag.data, 0, 1, data, 1, 2);
+ }
/**
* Creates a new matrix by reading it from a file.
@@ -211,6 +211,30 @@ public class ComplexDoubleMatrix {
return m;
}
+
+ /**
+ * Construct a matrix of arbitrary shape and set the diagonal according
+ * to a passed vector.
+ *
+ * The length of x must not exceed rows or columns, otherwise a SizeException is thrown.
+ *
+ * @param x vector to fill the diagonal with
+ * @param rows number of rows of the resulting matrix
+ * @param columns number of columns of the resulting matrix
+ * @return a matrix with dimensions rows * columns whose diagonal elements are filled by x
+ */
+ public static ComplexDoubleMatrix diag(ComplexDoubleMatrix x, int rows, int columns) {
+ if (x.length > rows || x.length > columns) {
+ throw new SizeException("Length of diagonal vector must not be larger than rows or columns.");
+ }
+
+ ComplexDoubleMatrix m = new ComplexDoubleMatrix(rows, columns);
+
+ for (int i = 0; i < x.length; i++)
+ m.put(i, i, x.get(i));
+
+ return m;
+ }
/**
* Create a 1 * 1 - matrix. For many operations, this matrix functions like a
@@ -281,7 +305,7 @@ public class ComplexDoubleMatrix {
}
public ComplexDoubleMatrix get(int[] indices, int c) {
- ComplexDoubleMatrix result = new ComplexDoubleMatrix(indices.length, c);
+ ComplexDoubleMatrix result = new ComplexDoubleMatrix(indices.length, 1);
for (int i = 0; i < indices.length; i++)
result.put(i, get(indices[i], c));
diff --git a/src/main/java/org/jblas/ComplexFloat.java b/src/main/java/org/jblas/ComplexFloat.java
index c8538ab..ca25653 100644
--- a/src/main/java/org/jblas/ComplexFloat.java
+++ b/src/main/java/org/jblas/ComplexFloat.java
@@ -330,4 +330,4 @@ public class ComplexFloat {
public boolean isImag() {
return r == 0.0f;
}
-}
+}
\ No newline at end of file
diff --git a/src/main/java/org/jblas/ComplexFloatMatrix.java b/src/main/java/org/jblas/ComplexFloatMatrix.java
index 2b11fed..489a8c0 100644
--- a/src/main/java/org/jblas/ComplexFloatMatrix.java
+++ b/src/main/java/org/jblas/ComplexFloatMatrix.java
@@ -57,18 +57,18 @@ public class ComplexFloatMatrix {
*
**************************************************************************/
- /** Create a new matrix with <i>newRows</i> rows, <i>newColumns</i> columns
- * using <i>newData></i> as the data. The length of the data is not checked!
+ /**
+ * Create a new matrix with <i>newRows</i> rows, <i>newColumns</i> columns
+ * using <i>newData</i> as the data; throws IllegalArgumentException unless newData.length == 2 * newRows * newColumns.
*/
public ComplexFloatMatrix(int newRows, int newColumns, float... newData) {
rows = newRows;
columns = newColumns;
length = rows * columns;
- if (newData.length != 2 * newRows * newColumns)
+ if (newData.length != 2 * newRows * newColumns)
throw new IllegalArgumentException(
"Passed data must match matrix dimensions.");
-
data = newData;
}
@@ -110,25 +110,25 @@ public class ComplexFloatMatrix {
}
- /** Construct a complex matrix from a real matrix. */
- public ComplexFloatMatrix(FloatMatrix m) {
- this(m.rows, m.columns);
-
- NativeBlas.scopy(m.length, m.data, 0, 1, data, 0, 2);
- }
-
- /** Construct a complex matrix from separate real and imaginary parts. Either
- * part can be set to null in which case it will be ignored.
- */
- public ComplexFloatMatrix(FloatMatrix real, FloatMatrix imag) {
- this(real.rows, real.columns);
- real.assertSameSize(imag);
-
- if (real != null)
- NativeBlas.scopy(length, real.data, 0, 1, data, 0, 2);
- if (imag != null)
- NativeBlas.scopy(length, imag.data, 0, 1, data, 1, 2);
- }
+ /** Construct a complex matrix from a real matrix. */
+ public ComplexFloatMatrix(FloatMatrix m) {
+ this(m.rows, m.columns);
+
+ NativeBlas.scopy(m.length, m.data, 0, 1, data, 0, 2);
+ }
+
+ /** Construct a complex matrix from separate real and imaginary parts. Either
+ * part (but not both) can be set to null in which case it will be ignored.
+ */
+ public ComplexFloatMatrix(FloatMatrix real, FloatMatrix imag) {
+ this(real != null ? real.rows : imag.rows, real != null ? real.columns : imag.columns);
+ if (real != null && imag != null) real.assertSameSize(imag); // size check only makes sense with both parts
+
+ if (real != null)
+ NativeBlas.scopy(length, real.data, 0, 1, data, 0, 2);
+ if (imag != null)
+ NativeBlas.scopy(length, imag.data, 0, 1, data, 1, 2);
+ }
/**
* Creates a new matrix by reading it from a file.
@@ -211,6 +211,30 @@ public class ComplexFloatMatrix {
return m;
}
+
+ /**
+ * Construct a matrix of arbitrary shape and set the diagonal according
+ * to a passed vector.
+ *
+ * The length of x must not exceed rows or columns, otherwise a SizeException is thrown.
+ *
+ * @param x vector to fill the diagonal with
+ * @param rows number of rows of the resulting matrix
+ * @param columns number of columns of the resulting matrix
+ * @return a matrix with dimensions rows * columns whose diagonal elements are filled by x
+ */
+ public static ComplexFloatMatrix diag(ComplexFloatMatrix x, int rows, int columns) {
+ if (x.length > rows || x.length > columns) {
+ throw new SizeException("Length of diagonal vector must not be larger than rows or columns.");
+ }
+
+ ComplexFloatMatrix m = new ComplexFloatMatrix(rows, columns);
+
+ for (int i = 0; i < x.length; i++)
+ m.put(i, i, x.get(i));
+
+ return m;
+ }
/**
* Create a 1 * 1 - matrix. For many operations, this matrix functions like a
@@ -281,7 +305,7 @@ public class ComplexFloatMatrix {
}
public ComplexFloatMatrix get(int[] indices, int c) {
- ComplexFloatMatrix result = new ComplexFloatMatrix(indices.length, c);
+ ComplexFloatMatrix result = new ComplexFloatMatrix(indices.length, 1);
for (int i = 0; i < indices.length; i++)
result.put(i, get(indices[i], c));
@@ -2047,4 +2071,4 @@ public class ComplexFloatMatrix {
return xori(new ComplexFloat(value));
}
//RJPP-END--------------------------------------------------------------
-}
+}
\ No newline at end of file
diff --git a/src/main/java/org/jblas/Decompose.java b/src/main/java/org/jblas/Decompose.java
index a3ada26..12af53a 100644
--- a/src/main/java/org/jblas/Decompose.java
+++ b/src/main/java/org/jblas/Decompose.java
@@ -13,8 +13,6 @@ import static org.jblas.util.Functions.min;
* Matrix which collects all kinds of decompositions.
*/
public class Decompose {
-
-//STOP
/**
* Class to hold an LU decomposition result.
*
@@ -33,8 +31,12 @@ public class Decompose {
this.u = u;
this.p = p;
}
+
+ @Override
+ public String toString() {
+ return String.format("<LUDecomposition L=%s U=%s P=%s>", l, u, p);
+ }
}
-//START
/**
* Compute LU Decomposition of a general matrix.
@@ -55,7 +57,7 @@ public class Decompose {
DoubleMatrix l = new DoubleMatrix(A.rows, min(A.rows, A.columns));
DoubleMatrix u = new DoubleMatrix(min(A.columns, A.rows), A.columns);
decomposeLowerUpper(result, l, u);
- DoubleMatrix p = Permutations.permutationMatrixFromPivotIndices(A.rows, ipiv);
+ DoubleMatrix p = Permutations.permutationDoubleMatrixFromPivotIndices(A.rows, ipiv);
return new LUDecomposition<DoubleMatrix>(l, u, p);
}
@@ -75,15 +77,15 @@ public class Decompose {
}
}
- /**
+ /**
* Compute Cholesky decomposition of A
*
* @param A symmetric, positive definite matrix (only upper half is used)
* @return upper triangular matrix U such that A = U' * U
*/
- public static DoubleMatrix cholesky(DoubleMatrix A) {
- DoubleMatrix result = A.dup();
- int info = NativeBlas.dpotrf('U', A.rows, result.data, 0, A.rows);
+ public static FloatMatrix cholesky(FloatMatrix A) {
+ FloatMatrix result = A.dup();
+ int info = NativeBlas.spotrf('U', A.rows, result.data, 0, A.rows);
if (info < 0) {
- throw new LapackArgumentException("DPOTRF", -info);
+ throw new LapackArgumentException("SPOTRF", -info);
} else if (info > 0) {
@@ -93,9 +95,146 @@ public class Decompose {
return result;
}
- private static void clearLower(DoubleMatrix A) {
+ private static void clearLower(FloatMatrix A) {
for (int j = 0; j < A.columns; j++)
for (int i = j + 1; i < A.rows; i++)
- A.put(i, j, 0.0);
+ A.put(i, j, 0.0f);
+ }
+
+ /**
+ * Compute LU Decomposition of a general matrix.
+ *
+ * Computes the LU decomposition using GETRF. Returns three matrices L, U, P,
+ * where L is lower diagonal, U is upper diagonal, and P is a permutation
+ * matrix such that A = P * L * U.
+ *
+ * @param A general matrix
+ * @return An LUDecomposition object.
+ */
+ public static LUDecomposition<FloatMatrix> lu(FloatMatrix A) {
+ int[] ipiv = new int[min(A.rows, A.columns)];
+ FloatMatrix result = A.dup();
+ NativeBlas.sgetrf(A.rows, A.columns, result.data, 0, A.rows, ipiv, 0);
+
+ // collect result
+ FloatMatrix l = new FloatMatrix(A.rows, min(A.rows, A.columns));
+ FloatMatrix u = new FloatMatrix(min(A.columns, A.rows), A.columns);
+ decomposeLowerUpper(result, l, u);
+ FloatMatrix p = Permutations.permutationFloatMatrixFromPivotIndices(A.rows, ipiv);
+ return new LUDecomposition<FloatMatrix>(l, u, p);
+ }
+
+ private static void decomposeLowerUpper(FloatMatrix A, FloatMatrix L, FloatMatrix U) {
+ for (int i = 0; i < A.rows; i++) {
+ for (int j = 0; j < A.columns; j++) {
+ if (i < j) {
+ U.put(i, j, A.get(i, j));
+ } else if (i == j) {
+ U.put(i, i, A.get(i, i));
+ L.put(i, i, 1.0f);
+ } else {
+ L.put(i, j, A.get(i, j));
+ }
+
+ }
+ }
+ }
+
+ /**
+ * Compute Cholesky decomposition of A
+ *
+ * @param A symmetric, positive definite matrix (only upper half is used)
+ * @return upper triangular matrix U such that A = U' * U
+ */
+ public static DoubleMatrix cholesky(DoubleMatrix A) {
+ DoubleMatrix result = A.dup();
+ int info = NativeBlas.dpotrf('U', A.rows, result.data, 0, A.rows);
+ if (info < 0) {
+ throw new LapackArgumentException("DPOTRF", -info);
+ } else if (info > 0) {
+ throw new LapackPositivityException("DPOTRF", "Minor " + info + " was negative. Matrix must be positive definite.");
+ }
+ clearLower(result);
+ return result;
+ }
+
+ private static void clearLower(DoubleMatrix A) {
+ for (int j = 0; j < A.columns; j++)
+ for (int i = j + 1; i < A.rows; i++)
+ A.put(i, j, 0.0);
+ }
+
+ /**
+ * Class to represent a QR decomposition.
+ *
+ * @param <T>
+ */
+ public static class QRDecomposition<T> {
+ public T q;
+ public T r;
+
+ QRDecomposition(T q, T r) {
+ this.q = q;
+ this.r = r;
+ }
+
+ @Override
+ public String toString() {
+ return "<Q=" + q + " R=" + r + ">";
+ }
+ }
+
+ /**
+ * QR decomposition.
+ *
+ * Decomposes (m,n) matrix A into a (m,m) matrix Q and an (m,n) matrix R such that
+ * Q is orthogonal, R is upper triangular and Q * R = A
+ *
+ * Note that if A has more rows than columns, then the lower rows of R will contain
+ * only zeros, such that the corresponding later columns of Q do not enter the computation
+ * at all. For some reason, LAPACK does not properly normalize those columns.
+ *
+ * @param A matrix
+ * @return QR decomposition
+ */
+ public static QRDecomposition<DoubleMatrix> qr(DoubleMatrix A) {
+ int minmn = min(A.rows, A.columns);
+ DoubleMatrix result = A.dup();
+ DoubleMatrix tau = new DoubleMatrix(minmn);
+ SimpleBlas.geqrf(result, tau);
+ DoubleMatrix R = new DoubleMatrix(A.rows, A.columns);
+ for (int i = 0; i < A.rows; i++) {
+ for (int j = i; j < A.columns; j++) {
+ R.put(i, j, result.get(i, j));
+ }
+ }
+ DoubleMatrix Q = DoubleMatrix.eye(A.rows);
+ SimpleBlas.ormqr('L', 'N', result, tau, Q);
+ return new QRDecomposition<DoubleMatrix>(Q, R);
+ }
+
+ /**
+ * QR decomposition.
+ *
+ * Decomposes (m,n) matrix A into a (m,m) matrix Q and an (m,n) matrix R such that
+ * Q is orthogonal, R is upper triangular and Q * R = A
+ *
+ * @param A matrix
+ * @return QR decomposition
+ */
+ public static QRDecomposition<FloatMatrix> qr(FloatMatrix A) {
+ int minmn = min(A.rows, A.columns);
+ FloatMatrix result = A.dup();
+ FloatMatrix tau = new FloatMatrix(minmn);
+ SimpleBlas.geqrf(result, tau);
+ FloatMatrix R = new FloatMatrix(A.rows, A.columns);
+ for (int i = 0; i < A.rows; i++) {
+ for (int j = i; j < A.columns; j++) {
+ R.put(i, j, result.get(i, j));
+ }
}
+ FloatMatrix Q = FloatMatrix.eye(A.rows);
+ SimpleBlas.ormqr('L', 'N', result, tau, Q);
+ return new QRDecomposition<FloatMatrix>(Q, R);
+ }
}
diff --git a/src/main/java/org/jblas/DoubleMatrix.java b/src/main/java/org/jblas/DoubleMatrix.java
index 0109047..83c00a0 100644
--- a/src/main/java/org/jblas/DoubleMatrix.java
+++ b/src/main/java/org/jblas/DoubleMatrix.java
@@ -39,6 +39,8 @@ package org.jblas;
import org.jblas.exceptions.SizeException;
import org.jblas.ranges.Range;
+import org.jblas.util.Random;
+
import java.io.BufferedReader;
import java.io.DataInputStream;
import java.io.DataOutputStream;
@@ -436,9 +438,8 @@ public class DoubleMatrix implements Serializable {
public static DoubleMatrix rand(int rows, int columns) {
DoubleMatrix m = new DoubleMatrix(rows, columns);
- java.util.Random r = new java.util.Random();
for (int i = 0; i < rows * columns; i++) {
- m.data[i] = r.nextDouble();
+ m.data[i] = (double) Random.nextDouble();
}
return m;
@@ -453,9 +454,8 @@ public class DoubleMatrix implements Serializable {
public static DoubleMatrix randn(int rows, int columns) {
DoubleMatrix m = new DoubleMatrix(rows, columns);
- java.util.Random r = new java.util.Random();
for (int i = 0; i < rows * columns; i++) {
- m.data[i] = (double) r.nextGaussian();
+ m.data[i] = (double) Random.nextGaussian();
}
return m;
@@ -517,6 +517,27 @@ public class DoubleMatrix implements Serializable {
return m;
}
+ /**
+ * Construct a matrix of arbitrary shape and set the diagonal according
+ * to a passed vector.
+ *
+ * The length of x must not exceed rows or columns, otherwise a SizeException is thrown.
+ *
+ * @param x vector to fill the diagonal with
+ * @param rows number of rows of the resulting matrix
+ * @param columns number of columns of the resulting matrix
+ * @return a matrix with dimensions rows * columns whose diagonal elements are filled by x
+ */
+ public static DoubleMatrix diag(DoubleMatrix x, int rows, int columns) {
+ if (x.length > rows || x.length > columns) throw new SizeException("Length of diagonal vector must not be larger than rows or columns.");
+ DoubleMatrix m = new DoubleMatrix(rows, columns);
+ for (int i = 0; i < x.length; i++) {
+ m.put(i, i, x.get(i));
+ }
+
+ return m;
+ }
+
/**
* Create a 1-by-1 matrix. For many operations, this matrix functions like a
* normal double.
@@ -617,7 +638,7 @@ public class DoubleMatrix implements Serializable {
/** Get all elements for a given column and the specified rows. */
public DoubleMatrix get(int[] indices, int c) {
- DoubleMatrix result = new DoubleMatrix(indices.length, c);
+ DoubleMatrix result = new DoubleMatrix(indices.length, 1);
for (int i = 0; i < indices.length; i++) {
result.put(i, get(indices[i], c));
@@ -646,6 +667,7 @@ public class DoubleMatrix implements Serializable {
DoubleMatrix result = new DoubleMatrix(rs.length(), cs.length());
for (; rs.hasMore(); rs.next()) {
+ cs.init(0, columns);
for (; cs.hasMore(); cs.next()) {
result.put(rs.index(), cs.index(), get(rs.value(), cs.value()));
}
@@ -836,7 +858,12 @@ public class DoubleMatrix implements Serializable {
}
- /** Set elements in linear ordering in the specified indices. */
+ /**
+ * Set elements in linear ordering in the specified indices.
+ *
+ * For example, <code>a.put(new int[]{ 1, 2, 0 }, new DoubleMatrix(3, 1, 2.0, 4.0, 8.0)</code>
+ * does <code>a.put(1, 2.0), a.put(2, 4.0), a.put(0, 8.0)</code>.
+ */
public DoubleMatrix put(int[] indices, DoubleMatrix x) {
if (x.isScalar()) {
return put(indices, x.scalar());
@@ -904,6 +931,7 @@ public class DoubleMatrix implements Serializable {
x.checkColumns(cs.length());
for (; rs.hasMore(); rs.next()) {
+ cs.init(0, columns);
for (; cs.hasMore(); cs.next()) {
put(rs.value(), cs.value(), x.get(rs.index(), cs.index()));
}
@@ -1292,25 +1320,7 @@ public class DoubleMatrix implements Serializable {
/** Generate string representation of the matrix. */
@Override
public String toString() {
- StringBuilder s = new StringBuilder();
-
- s.append("[");
-
- for (int i = 0; i < rows; i++) {
- for (int j = 0; j < columns; j++) {
- s.append(get(i, j));
- if (j < columns - 1) {
- s.append(", ");
- }
- }
- if (i < rows - 1) {
- s.append("; ");
- }
- }
-
- s.append("]");
-
- return s.toString();
+ return toString("%f");
}
/**
@@ -1320,24 +1330,38 @@ public class DoubleMatrix implements Serializable {
* decimal point.
*/
public String toString(String fmt) {
+ return toString(fmt, "[", "]", ", ", "; ");
+ }
+
+ /**
+ * Generate string representation of the matrix, with specified
+ * format for the entries, and delimiters.
+ *
+ * @param fmt entry format (passed to String.format())
+ * @param open opening parenthesis
+ * @param close closing parenthesis
+ * @param colSep separator between columns
+ * @param rowSep separator between rows
+ */
+ public String toString(String fmt, String open, String close, String colSep, String rowSep) {
StringWriter s = new StringWriter();
PrintWriter p = new PrintWriter(s);
- p.print("[");
+ p.print(open);
for (int r = 0; r < rows; r++) {
for (int c = 0; c < columns; c++) {
p.printf(fmt, get(r, c));
if (c < columns - 1) {
- p.print(", ");
+ p.print(colSep);
}
}
if (r < rows - 1) {
- p.print("; ");
+ p.print(rowSep);
}
}
- p.print("]");
+ p.print(close);
return s.toString();
}
@@ -1790,6 +1814,30 @@ public class DoubleMatrix implements Serializable {
return dup().isInfinitei();
}
+ /** Checks whether all entries (i, j) with i < j (strictly above the diagonal) are zero. */
+ public boolean isLowerTriangular() {
+ for (int i = 0; i < rows; i++)
+ for (int j = i+1; j < columns; j++) {
+ if (get(i, j) != 0.0)
+ return false;
+ }
+
+ return true;
+ }
+
+ /**
+ * Checks whether all entries (i, j) with i > j (strictly below the diagonal) are zero.
+ */
+ public boolean isUpperTriangular() {
+ for (int i = 0; i < rows; i++)
+ for (int j = 0; j < i && j < columns; j++) {
+ if (get(i, j) != 0.0)
+ return false;
+ }
+
+ return true;
+ }
+
public DoubleMatrix selecti(DoubleMatrix where) {
checkLength(where.length);
for (int i = 0; i < length; i++) {
@@ -1893,7 +1941,7 @@ public class DoubleMatrix implements Serializable {
}
}
}
- return this;
+ return result;
}
/**
@@ -1929,7 +1977,7 @@ public class DoubleMatrix implements Serializable {
}
}
- return this;
+ return result;
}
public DoubleMatrix mini(double v) {
@@ -1995,7 +2043,7 @@ public class DoubleMatrix implements Serializable {
}
}
}
- return this;
+ return result;
}
/**
@@ -2031,7 +2079,7 @@ public class DoubleMatrix implements Serializable {
}
}
- return this;
+ return result;
}
public DoubleMatrix maxi(double v) {
@@ -2630,6 +2678,7 @@ public class DoubleMatrix implements Serializable {
public void save(String filename) throws IOException {
DataOutputStream dos = new DataOutputStream(new FileOutputStream(filename, false));
- this.out(dos);
+ try { this.out(dos); }
+ finally { dos.close(); } // release the file handle even if out() throws
}
/**
@@ -2641,6 +2690,7 @@ public class DoubleMatrix implements Serializable {
public void load(String filename) throws IOException {
DataInputStream dis = new DataInputStream(new FileInputStream(filename));
- this.in(dis);
+ try { this.in(dis); }
+ finally { dis.close(); } // release the file handle even if in() throws
}
public static DoubleMatrix loadAsciiFile(String filename) throws IOException {
@@ -3377,4 +3427,8 @@ public class DoubleMatrix implements Serializable {
return xori(value, new DoubleMatrix(rows, columns));
}
//RJPP-END--------------------------------------------------------------
+
+ public ComplexDoubleMatrix toComplex() {
+ return new ComplexDoubleMatrix(this);
+ }
}
diff --git a/src/main/java/org/jblas/FloatFunction.java b/src/main/java/org/jblas/FloatFunction.java
index 23cc618..91f6a88 100644
--- a/src/main/java/org/jblas/FloatFunction.java
+++ b/src/main/java/org/jblas/FloatFunction.java
@@ -42,4 +42,4 @@ package org.jblas;
public interface FloatFunction {
/** Compute the function. */
public float compute(float x);
-}
+}
\ No newline at end of file
diff --git a/src/main/java/org/jblas/FloatMatrix.java b/src/main/java/org/jblas/FloatMatrix.java
index 1251a39..4e9f5df 100644
--- a/src/main/java/org/jblas/FloatMatrix.java
+++ b/src/main/java/org/jblas/FloatMatrix.java
@@ -39,6 +39,8 @@ package org.jblas;
import org.jblas.exceptions.SizeException;
import org.jblas.ranges.Range;
+import org.jblas.util.Random;
+
import java.io.BufferedReader;
import java.io.DataInputStream;
import java.io.DataOutputStream;
@@ -436,9 +438,8 @@ public class FloatMatrix implements Serializable {
public static FloatMatrix rand(int rows, int columns) {
FloatMatrix m = new FloatMatrix(rows, columns);
- java.util.Random r = new java.util.Random();
for (int i = 0; i < rows * columns; i++) {
- m.data[i] = r.nextFloat();
+ m.data[i] = (float) Random.nextFloat();
}
return m;
@@ -453,9 +454,8 @@ public class FloatMatrix implements Serializable {
public static FloatMatrix randn(int rows, int columns) {
FloatMatrix m = new FloatMatrix(rows, columns);
- java.util.Random r = new java.util.Random();
for (int i = 0; i < rows * columns; i++) {
- m.data[i] = (float) r.nextGaussian();
+ m.data[i] = (float) Random.nextGaussian();
}
return m;
@@ -517,6 +517,27 @@ public class FloatMatrix implements Serializable {
return m;
}
+ /**
+ * Construct a matrix of arbitrary shape and set the diagonal according
+ * to a passed vector.
+ *
+ * The length of x must not exceed rows or columns, otherwise a SizeException is thrown.
+ *
+ * @param x vector to fill the diagonal with
+ * @param rows number of rows of the resulting matrix
+ * @param columns number of columns of the resulting matrix
+ * @return a matrix with dimensions rows * columns whose diagonal elements are filled by x
+ */
+ public static FloatMatrix diag(FloatMatrix x, int rows, int columns) {
+ if (x.length > rows || x.length > columns) throw new SizeException("Length of diagonal vector must not be larger than rows or columns.");
+ FloatMatrix m = new FloatMatrix(rows, columns);
+ for (int i = 0; i < x.length; i++) {
+ m.put(i, i, x.get(i));
+ }
+
+ return m;
+ }
+
/**
* Create a 1-by-1 matrix. For many operations, this matrix functions like a
* normal float.
@@ -617,7 +638,7 @@ public class FloatMatrix implements Serializable {
/** Get all elements for a given column and the specified rows. */
public FloatMatrix get(int[] indices, int c) {
- FloatMatrix result = new FloatMatrix(indices.length, c);
+ FloatMatrix result = new FloatMatrix(indices.length, 1);
for (int i = 0; i < indices.length; i++) {
result.put(i, get(indices[i], c));
@@ -646,6 +667,7 @@ public class FloatMatrix implements Serializable {
FloatMatrix result = new FloatMatrix(rs.length(), cs.length());
for (; rs.hasMore(); rs.next()) {
+ cs.init(0, columns);
for (; cs.hasMore(); cs.next()) {
result.put(rs.index(), cs.index(), get(rs.value(), cs.value()));
}
@@ -836,7 +858,12 @@ public class FloatMatrix implements Serializable {
}
- /** Set elements in linear ordering in the specified indices. */
+ /**
+ * Set elements in linear ordering in the specified indices.
+ *
+ * For example, <code>a.put(new int[]{ 1, 2, 0 }, new FloatMatrix(3, 1, 2.0f, 4.0f, 8.0f)</code>
+ * does <code>a.put(1, 2.0f), a.put(2, 4.0f), a.put(0, 8.0f)</code>.
+ */
public FloatMatrix put(int[] indices, FloatMatrix x) {
if (x.isScalar()) {
return put(indices, x.scalar());
@@ -904,6 +931,7 @@ public class FloatMatrix implements Serializable {
x.checkColumns(cs.length());
for (; rs.hasMore(); rs.next()) {
+ cs.init(0, columns);
for (; cs.hasMore(); cs.next()) {
put(rs.value(), cs.value(), x.get(rs.index(), cs.index()));
}
@@ -1292,25 +1320,7 @@ public class FloatMatrix implements Serializable {
/** Generate string representation of the matrix. */
@Override
public String toString() {
- StringBuilder s = new StringBuilder();
-
- s.append("[");
-
- for (int i = 0; i < rows; i++) {
- for (int j = 0; j < columns; j++) {
- s.append(get(i, j));
- if (j < columns - 1) {
- s.append(", ");
- }
- }
- if (i < rows - 1) {
- s.append("; ");
- }
- }
-
- s.append("]");
-
- return s.toString();
+ return toString("%f");
}
/**
@@ -1320,24 +1330,38 @@ public class FloatMatrix implements Serializable {
* decimal point.
*/
public String toString(String fmt) {
+ return toString(fmt, "[", "]", ", ", "; ");
+ }
+
+ /**
+ * Generate string representation of the matrix, with specified
+ * format for the entries, and delimiters.
+ *
+ * @param fmt entry format (passed to String.format())
+ * @param open opening parenthesis
+ * @param close closing parenthesis
+ * @param colSep separator between columns
+ * @param rowSep separator between rows
+ */
+ public String toString(String fmt, String open, String close, String colSep, String rowSep) {
StringWriter s = new StringWriter();
PrintWriter p = new PrintWriter(s);
- p.print("[");
+ p.print(open);
for (int r = 0; r < rows; r++) {
for (int c = 0; c < columns; c++) {
p.printf(fmt, get(r, c));
if (c < columns - 1) {
- p.print(", ");
+ p.print(colSep);
}
}
if (r < rows - 1) {
- p.print("; ");
+ p.print(rowSep);
}
}
- p.print("]");
+ p.print(close);
return s.toString();
}
@@ -1790,6 +1814,30 @@ public class FloatMatrix implements Serializable {
return dup().isInfinitei();
}
+ /** Checks whether all entries (i, j) with i < j (strictly above the diagonal) are zero. */
+ public boolean isLowerTriangular() {
+ for (int i = 0; i < rows; i++)
+ for (int j = i+1; j < columns; j++) {
+ if (get(i, j) != 0.0f)
+ return false;
+ }
+
+ return true;
+ }
+
+ /**
+ * Checks whether all entries (i, j) with i > j (strictly below the diagonal) are zero.
+ */
+ public boolean isUpperTriangular() {
+ for (int i = 0; i < rows; i++)
+ for (int j = 0; j < i && j < columns; j++) {
+ if (get(i, j) != 0.0f)
+ return false;
+ }
+
+ return true;
+ }
+
public FloatMatrix selecti(FloatMatrix where) {
checkLength(where.length);
for (int i = 0; i < length; i++) {
@@ -1893,7 +1941,7 @@ public class FloatMatrix implements Serializable {
}
}
}
- return this;
+ return result;
}
/**
@@ -1929,7 +1977,7 @@ public class FloatMatrix implements Serializable {
}
}
- return this;
+ return result;
}
public FloatMatrix mini(float v) {
@@ -1995,7 +2043,7 @@ public class FloatMatrix implements Serializable {
}
}
}
- return this;
+ return result;
}
/**
@@ -2031,7 +2079,7 @@ public class FloatMatrix implements Serializable {
}
}
- return this;
+ return result;
}
public FloatMatrix maxi(float v) {
@@ -2630,6 +2678,7 @@ public class FloatMatrix implements Serializable {
public void save(String filename) throws IOException {
DataOutputStream dos = new DataOutputStream(new FileOutputStream(filename, false));
- this.out(dos);
+ try { this.out(dos); }
+ finally { dos.close(); } // release the file handle even if out() throws
}
/**
@@ -2641,6 +2690,7 @@ public class FloatMatrix implements Serializable {
public void load(String filename) throws IOException {
DataInputStream dis = new DataInputStream(new FileInputStream(filename));
- this.in(dis);
+ try { this.in(dis); }
+ finally { dis.close(); } // release the file handle even if in() throws
}
public static FloatMatrix loadAsciiFile(String filename) throws IOException {
@@ -3377,4 +3427,8 @@ public class FloatMatrix implements Serializable {
return xori(value, new FloatMatrix(rows, columns));
}
//RJPP-END--------------------------------------------------------------
-}
+
+ public ComplexFloatMatrix toComplex() {
+ return new ComplexFloatMatrix(this);
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/org/jblas/Info.java b/src/main/java/org/jblas/Info.java
new file mode 100644
index 0000000..d69e6b1
--- /dev/null
+++ b/src/main/java/org/jblas/Info.java
@@ -0,0 +1,14 @@
+package org.jblas;
+
+/**
+ * Version information for the jblas library.
+ * <p/>
+ * {@link #VERSION} holds the release version string of this build.
+ * <p/>
+ * User: mikio
+ * Date: 2/12/13
+ * Time: 3:28 PM
+ */
+public class Info {
+ public static final String VERSION = "1.2.3";
+}
diff --git a/src/main/java/org/jblas/NativeBlas.java b/src/main/java/org/jblas/NativeBlas.java
index 32fb7fb..4adb2b4 100644
--- a/src/main/java/org/jblas/NativeBlas.java
+++ b/src/main/java/org/jblas/NativeBlas.java
@@ -74,21 +74,13 @@ import org.jblas.util.Logger;
public class NativeBlas {
static {
- try {
- System.loadLibrary("jblas");
- } catch (UnsatisfiedLinkError e) {
- Logger.getLogger().config(
- "BLAS native library not found in path. Copying native library "
- + "from the archive. Consider installing the library somewhere "
- + "in the path (for Windows: PATH, for Linux: LD_LIBRARY_PATH).");
- new org.jblas.util.LibraryLoader().loadLibrary("jblas", true);
- }
- }
- private static int[] intDummy = new int[1];
- private static double[] doubleDummy = new double[1];
- private static float[] floatDummy = new float[1];
-
-
+ NativeBlasLibraryLoader.loadLibraryAndCheckErrors();
+ }
+
+ private static int[] intDummy = new int[1];
+ private static double[] doubleDummy = new double[1];
+ private static float[] floatDummy = new float[1];
+
public static native void ccopy(int n, float[] cx, int cxIdx, int incx, float[] cy, int cyIdx, int incy);
public static native void dcopy(int n, double[] dx, int dxIdx, int incx, double[] dy, int dyIdx, int incy);
public static native void scopy(int n, float[] sx, int sxIdx, int incx, float[] sy, int syIdx, int incy);
@@ -425,5 +417,84 @@ public class NativeBlas {
return info;
}
+ // JNI binding for LAPACK DGELSD; the caller supplies the work array and lwork.
+ public static native int dgelsd(int m, int n, int nrhs, double[] a, int aIdx, int lda, double[] b, int bIdx, int ldb, double[] s, int sIdx, double rcond, int[] rank, int rankIdx, double[] work, int workIdx, int lwork, int[] iwork, int iworkIdx);
+ // Convenience overload that allocates the work array itself. It first calls DGELSD with
+ // lwork = -1 (a workspace query on the dummy arrays), which reports the optimal size in
+ // work[0], then makes the real call. iwork is NOT allocated here; the caller must size it.
+ // Returns LAPACK's info code (0 on success).
+ public static int dgelsd(int m, int n, int nrhs, double[] a, int aIdx, int lda, double[] b, int bIdx, int ldb, double[] s, int sIdx, double rcond, int[] rank, int rankIdx, int[] iwork, int iworkIdx) {
+ int info;
+ double[] work = new double[1];
+ int lwork;
+ info = dgelsd(m, n, nrhs, doubleDummy, 0, lda, doubleDummy, 0, ldb, doubleDummy, 0, rcond, intDummy, 0, work, 0, -1, intDummy, 0);
+ if (info != 0)
+ return info;
+ lwork = (int) work[0]; work = new double[lwork];
+ info = dgelsd(m, n, nrhs, a, aIdx, lda, b, bIdx, ldb, s, sIdx, rcond, rank, rankIdx, work, 0, lwork, iwork, iworkIdx);
+ return info;
+ }
+
+ // JNI binding for LAPACK SGELSD (single precision); the caller supplies work and lwork.
+ public static native int sgelsd(int m, int n, int nrhs, float[] a, int aIdx, int lda, float[] b, int bIdx, int ldb, float[] s, int sIdx, float rcond, int[] rank, int rankIdx, float[] work, int workIdx, int lwork, int[] iwork, int iworkIdx);
+ // Convenience overload: workspace query (lwork = -1, size returned in work[0]), then the
+ // real call. iwork must be sized by the caller. Returns LAPACK's info code.
+ public static int sgelsd(int m, int n, int nrhs, float[] a, int aIdx, int lda, float[] b, int bIdx, int ldb, float[] s, int sIdx, float rcond, int[] rank, int rankIdx, int[] iwork, int iworkIdx) {
+ int info;
+ float[] work = new float[1];
+ int lwork;
+ info = sgelsd(m, n, nrhs, floatDummy, 0, lda, floatDummy, 0, ldb, floatDummy, 0, rcond, intDummy, 0, work, 0, -1, intDummy, 0);
+ if (info != 0)
+ return info;
+ lwork = (int) work[0]; work = new float[lwork];
+ info = sgelsd(m, n, nrhs, a, aIdx, lda, b, bIdx, ldb, s, sIdx, rcond, rank, rankIdx, work, 0, lwork, iwork, iworkIdx);
+ return info;
+ }
+
+ // JNI binding for LAPACK ILAENV, which returns routine-dependent tuning parameters
+ // (e.g. ispec = 9 yields SMLSIZ, as queried by SimpleBlas.gelsd).
+ public static native int ilaenv(int ispec, String name, String opts, int n1, int n2, int n3, int n4);
+ // JNI binding for LAPACK DGEQRF (QR factorization); the caller supplies work and lwork.
+ public static native int dgeqrf(int m, int n, double[] a, int aIdx, int lda, double[] tau, int tauIdx, double[] work, int workIdx, int lwork);
+ // Convenience overload: workspace query (lwork = -1, size returned in work[0]), then the
+ // real call. Returns LAPACK's info code.
+ public static int dgeqrf(int m, int n, double[] a, int aIdx, int lda, double[] tau, int tauIdx) {
+ int info;
+ double[] work = new double[1];
+ int lwork;
+ info = dgeqrf(m, n, doubleDummy, 0, lda, doubleDummy, 0, work, 0, -1);
+ if (info != 0)
+ return info;
+ lwork = (int) work[0]; work = new double[lwork];
+ info = dgeqrf(m, n, a, aIdx, lda, tau, tauIdx, work, 0, lwork);
+ return info;
+ }
+
+ // JNI binding for LAPACK SGEQRF (single-precision QR); the caller supplies work and lwork.
+ public static native int sgeqrf(int m, int n, float[] a, int aIdx, int lda, float[] tau, int tauIdx, float[] work, int workIdx, int lwork);
+ // Convenience overload: workspace query (lwork = -1, size returned in work[0]), then the
+ // real call. Returns LAPACK's info code.
+ public static int sgeqrf(int m, int n, float[] a, int aIdx, int lda, float[] tau, int tauIdx) {
+ int info;
+ float[] work = new float[1];
+ int lwork;
+ info = sgeqrf(m, n, floatDummy, 0, lda, floatDummy, 0, work, 0, -1);
+ if (info != 0)
+ return info;
+ lwork = (int) work[0]; work = new float[lwork];
+ info = sgeqrf(m, n, a, aIdx, lda, tau, tauIdx, work, 0, lwork);
+ return info;
+ }
+
+ // JNI binding for LAPACK DORMQR (apply Q from DGEQRF to C); the caller supplies work and lwork.
+ public static native int dormqr(char side, char trans, int m, int n, int k, double[] a, int aIdx, int lda, double[] tau, int tauIdx, double[] c, int cIdx, int ldc, double[] work, int workIdx, int lwork);
+ // Convenience overload: workspace query (lwork = -1, size returned in work[0]), then the
+ // real call. Returns LAPACK's info code.
+ public static int dormqr(char side, char trans, int m, int n, int k, double[] a, int aIdx, int lda, double[] tau, int tauIdx, double[] c, int cIdx, int ldc) {
+ int info;
+ double[] work = new double[1];
+ int lwork;
+ info = dormqr(side, trans, m, n, k, doubleDummy, 0, lda, doubleDummy, 0, doubleDummy, 0, ldc, work, 0, -1);
+ if (info != 0)
+ return info;
+ lwork = (int) work[0]; work = new double[lwork];
+ info = dormqr(side, trans, m, n, k, a, aIdx, lda, tau, tauIdx, c, cIdx, ldc, work, 0, lwork);
+ return info;
+ }
+
+ // JNI binding for LAPACK SORMQR (single precision); the caller supplies work and lwork.
+ public static native int sormqr(char side, char trans, int m, int n, int k, float[] a, int aIdx, int lda, float[] tau, int tauIdx, float[] c, int cIdx, int ldc, float[] work, int workIdx, int lwork);
+ // Convenience overload: workspace query (lwork = -1, size returned in work[0]), then the
+ // real call. Returns LAPACK's info code.
+ public static int sormqr(char side, char trans, int m, int n, int k, float[] a, int aIdx, int lda, float[] tau, int tauIdx, float[] c, int cIdx, int ldc) {
+ int info;
+ float[] work = new float[1];
+ int lwork;
+ info = sormqr(side, trans, m, n, k, floatDummy, 0, lda, floatDummy, 0, floatDummy, 0, ldc, work, 0, -1);
+ if (info != 0)
+ return info;
+ lwork = (int) work[0]; work = new float[lwork];
+ info = sormqr(side, trans, m, n, k, a, aIdx, lda, tau, tauIdx, c, cIdx, ldc, work, 0, lwork);
+ return info;
+ }
+
}
diff --git a/src/main/java/org/jblas/NativeBlasLibraryLoader.java b/src/main/java/org/jblas/NativeBlasLibraryLoader.java
new file mode 100644
index 0000000..3460bfb
--- /dev/null
+++ b/src/main/java/org/jblas/NativeBlasLibraryLoader.java
@@ -0,0 +1,70 @@
+package org.jblas;
+
+import org.jblas.exceptions.UnsupportedArchitectureException;
+import org.jblas.util.LibraryLoader;
+import org.jblas.util.Logger;
+
+/**
+ * Helper class that loads the native libraries needed by {@link NativeBlas}.
+ * <p/>
+ * {@link NativeBlas} calls {@link #loadLibraryAndCheckErrors()} from its static
+ * initializer. Load failures are reported on stderr (with platform-specific hints
+ * where available) rather than thrown out of the initializer.
+ * <p/>
+ * User: Mikio L. Braun
+ * Date: 10/24/12
+ * Time: 3:15 PM
+ */
+class NativeBlasLibraryLoader {
+    /**
+     * Loads libjblas, first from the library path and otherwise through {@link LibraryLoader},
+     * then performs a tiny dgemm call so that missing dependent libraries surface here.
+     */
+    static void loadLibraryAndCheckErrors() {
+        try {
+            try {
+                loadDependentLibraries();
+                System.loadLibrary("jblas");
+            } catch (UnsatisfiedLinkError e) {
+                Logger.getLogger().config(
+                        "BLAS native library not found in path. Copying native library "
+                                + "from the archive. Consider installing the library somewhere "
+                                + "in the path (for Windows: PATH, for Linux: LD_LIBRARY_PATH).");
+                new LibraryLoader().loadLibrary("jblas", true, false);
+            }
+            // Let's do some quick tests to see whether we trigger some errors
+            // when dependent libraries cannot be found
+            double[] a = new double[1];
+            NativeBlas.dgemm('N', 'N', 1, 1, 1, 1.0, a, 0, 1, a, 0, 1, 1.0, a, 0, 1);
+        } catch (UnsatisfiedLinkError e) {
+            String arch = System.getProperty("os.arch");
+            String name = System.getProperty("os.name");
+            // getMessage() may be null; never let the diagnostics path itself throw an NPE.
+            String message = e.getMessage() == null ? "" : e.getMessage();
+
+            if (name.startsWith("Windows") && message.contains("Can't find dependent libraries")) {
+                System.err.println("On Windows, you need some additional support libraries.\n" +
+                        "For example, you can install the two packages in cygwin:\n" +
+                        "\n" +
+                        " mingw64-x86_64-gcc-core mingw64-x86_64-gfortran\n" +
+                        "\n" +
+                        "and add the directory <cygwin-home>\\usr\\x86_64-w64-mingw32\\sys-root\\mingw\\bin to your path.\n\n" +
+                        "For more information, see http://github.com/mikiobraun/jblas/wiki/Missing-Libraries");
+            } else if (name.equals("Linux") && arch.equals("amd64")) {
+                System.err.println("On Linux 64bit, you need additional support libraries.\n" +
+                        "You need to install libgfortran3.\n\n" +
+                        "For example for debian or Ubuntu, type \"sudo apt-get install libgfortran3\"\n\n" +
+                        "For more information, see https://github.com/mikiobraun/jblas/wiki/Missing-Libraries");
+            } else {
+                // Previously swallowed silently: report it so the later UnsatisfiedLinkError
+                // from the first NativeBlas call is not the only clue.
+                System.err.println("Could not load the jblas native library: " + message);
+            }
+        } catch (UnsupportedArchitectureException e) {
+            // Presumably raised by LibraryLoader when no bundled build matches this platform.
+            System.err.println(e.getMessage());
+        }
+    }
+
+    /**
+     * On 64-bit and 32-bit Windows, loads the gcc and gfortran runtime libraries
+     * (libgcc_s_sjlj-1 / libgcc_s_dw2-1 and libgfortran-3) through {@link LibraryLoader}
+     * before libjblas itself. Does nothing on other platforms.
+     */
+    public static void loadDependentLibraries() {
+        String arch = System.getProperty("os.arch");
+        String name = System.getProperty("os.name");
+
+        LibraryLoader loader = new LibraryLoader();
+
+        if (name.startsWith("Windows") && arch.equals("amd64")) {
+            loader.loadLibrary("libgcc_s_sjlj-1", false, true);
+            loader.loadLibrary("libgfortran-3", false, true);
+        } else if (name.startsWith("Windows") && arch.equals("x86")) {
+            loader.loadLibrary("libgcc_s_dw2-1", false, true);
+            loader.loadLibrary("libgfortran-3", false, true);
+        }
+    }
+}
diff --git a/src/main/java/org/jblas/SimpleBlas.java b/src/main/java/org/jblas/SimpleBlas.java
index 56a7b93..0885a4a 100644
--- a/src/main/java/org/jblas/SimpleBlas.java
+++ b/src/main/java/org/jblas/SimpleBlas.java
@@ -37,10 +37,10 @@
package org.jblas;
-import org.jblas.exceptions.LapackException;
-import org.jblas.exceptions.LapackArgumentException;
-import org.jblas.exceptions.LapackConvergenceException;
-import org.jblas.exceptions.LapackSingularityException;
+import org.jblas.exceptions.*;
+import org.jblas.util.Functions;
+
+import static org.jblas.util.Functions.*;
//import edu.ida.core.OutputValue;
@@ -161,8 +161,14 @@ public class SimpleBlas {
return NativeBlas.idamax(x.length, x.data, 0, 1) - 1;
}
+ /**
+ * Compute index of element with largest absolute value (complex version).
+ * <p/>
+ * Note: reference BLAS izamax ranks complex elements by |re| + |im|, not by the modulus.
+ *
+ * @param x matrix
+ * @return 0-based index of element with largest absolute value (-1 if x is empty).
+ */
public static int iamax(ComplexDoubleMatrix x) {
+ // izamax returns a 1-based (Fortran) index; convert it to a 0-based Java index.
- return NativeBlas.izamax(x.length, x.data, 0, 1);
+ return NativeBlas.izamax(x.length, x.data, 0, 1) - 1;
}
/***************************************************************************
@@ -393,6 +399,63 @@ public class SimpleBlas {
}
}
+    /**
+     * Least squares / minimum norm solution of A * X = B via LAPACK *GELSD (SVD-based).
+     *
+     * Note that B must be padded to contain the solution matrix. This occurs when A has fewer rows
+     * than columns.
+     *
+     * For example: in A * X = B, A is (m,n), X is (n,k) and B is (m,k). Now if m < n, since B is overwritten to contain
+     * the solution (in classical LAPACK style), B needs to be padded to be an (n,k) matrix.
+     *
+     * Likewise, if m > n, the solution consists only of the first n rows of B.
+     *
+     * rcond = -1 is passed, so machine precision decides which singular values count as zero.
+     * A is destroyed by LAPACK; pass a copy if it is still needed.
+     *
+     * @param A an (m,n) matrix
+     * @param B an (max(m,n), k) matrix (well, at least)
+     * @throws SizeException if B has fewer than max(m,n) rows
+     * @throws LapackArgumentException if GELSD reports an illegal argument
+     * @throws LapackConvergenceException if the SVD inside GELSD fails to converge
+     */
+    public static void gelsd(DoubleMatrix A, DoubleMatrix B) {
+        int m = A.rows;
+        int n = A.columns;
+        int nrhs = B.columns;
+        int minmn = min(m, n);
+        int maxmn = max(m, n);
+
+        if (B.rows < maxmn) {
+            throw new SizeException("Result matrix B must be padded to contain the solution matrix X!");
+        }
+
+        int smlsiz = NativeBlas.ilaenv(9, "DGELSD", "", m, n, nrhs, 0);
+        // Same formula as DGELSD itself: real division, and (int) truncates toward zero like
+        // Fortran INT. Integer division here used to yield 0 instead of 1 whenever
+        // (smlsiz + 1) / 2 < minmn <= smlsiz, under-sizing iwork.
+        int nlvl = max(0, (int) log2((double) minmn / (smlsiz + 1)) + 1);
+
+        // LAPACK requires IWORK to have at least max(1, LIWORK) entries.
+        int[] iwork = new int[max(1, 3 * minmn * nlvl + 11 * minmn)];
+        double[] s = new double[minmn];
+        int[] rank = new int[1];
+        // LAPACK requires lda >= max(1, m) and ldb >= max(1, max(m, n)); guard empty inputs.
+        int info = NativeBlas.dgelsd(m, n, nrhs, A.data, 0, max(1, m), B.data, 0, max(1, B.rows),
+                s, 0, -1, rank, 0, iwork, 0);
+        if (info < 0) {
+            throw new LapackArgumentException("GELSD", -info);
+        } else if (info > 0) {
+            throw new LapackConvergenceException("GELSD", info + " off-diagonal elements of an intermediate bidiagonal form did not converge to 0.");
+        }
+    }
+
+ /**
+ * QR factorization of A via LAPACK GEQRF, computed in place.
+ *
+ * @param A (m,n) matrix; overwritten with R (upper triangle) and the Householder reflectors (below it)
+ * @param tau receives the scalar factors of the reflectors; needs at least min(m,n) entries
+ */
+ public static void geqrf(DoubleMatrix A, DoubleMatrix tau) {
+ int info = NativeBlas.dgeqrf(A.rows, A.columns, A.data, 0, A.rows, tau.data, 0);
+ // Non-zero info codes are turned into exceptions by checkInfo.
+ checkInfo("GEQRF", info);
+ }
+
+ /**
+ * Multiplies C by the orthogonal matrix Q from a previous geqrf (LAPACK ORMQR), in place.
+ *
+ * @param side 'L' to compute Q*C or Q^T*C, 'R' to compute C*Q or C*Q^T
+ * @param trans 'N' to apply Q, 'T' to apply Q^T
+ * @param A the reflectors as left in A by geqrf
+ * @param tau reflector scalars from geqrf; tau.length is passed as k, the number of reflectors
+ * @param C the matrix to multiply; overwritten with the result
+ */
+ public static void ormqr(char side, char trans, DoubleMatrix A, DoubleMatrix tau, DoubleMatrix C) {
+ int k = tau.length;
+ int info = NativeBlas.dormqr(side, trans, C.rows, C.columns, k, A.data, 0, A.rows, tau.data, 0, C.data, 0, C.rows);
+ checkInfo("ORMQR", info);
+ }
+
//BEGIN
// The code below has been automatically generated.
// DO NOT EDIT!
@@ -503,8 +566,14 @@ public class SimpleBlas {
return NativeBlas.isamax(x.length, x.data, 0, 1) - 1;
}
+ /**
+ * Compute index of element with largest absolute value (complex version).
+ * <p/>
+ * Note: reference BLAS icamax ranks complex elements by |re| + |im|, not by the modulus.
+ *
+ * @param x matrix
+ * @return 0-based index of element with largest absolute value (-1 if x is empty).
+ */
public static int iamax(ComplexFloatMatrix x) {
+ // icamax returns a 1-based (Fortran) index; convert it to a 0-based Java index.
- return NativeBlas.icamax(x.length, x.data, 0, 1);
+ return NativeBlas.icamax(x.length, x.data, 0, 1) - 1;
}
/***************************************************************************
@@ -728,5 +797,62 @@ public class SimpleBlas {
}
}
+    /**
+     * Least squares / minimum norm solution of A * X = B via LAPACK *GELSD (SVD-based).
+     *
+     * Note that B must be padded to contain the solution matrix. This occurs when A has fewer rows
+     * than columns.
+     *
+     * For example: in A * X = B, A is (m,n), X is (n,k) and B is (m,k). Now if m < n, since B is overwritten to contain
+     * the solution (in classical LAPACK style), B needs to be padded to be an (n,k) matrix.
+     *
+     * Likewise, if m > n, the solution consists only of the first n rows of B.
+     *
+     * rcond = -1 is passed, so machine precision decides which singular values count as zero.
+     * A is destroyed by LAPACK; pass a copy if it is still needed.
+     *
+     * @param A an (m,n) matrix
+     * @param B an (max(m,n), k) matrix (well, at least)
+     * @throws SizeException if B has fewer than max(m,n) rows
+     * @throws LapackArgumentException if GELSD reports an illegal argument
+     * @throws LapackConvergenceException if the SVD inside GELSD fails to converge
+     */
+    public static void gelsd(FloatMatrix A, FloatMatrix B) {
+        int m = A.rows;
+        int n = A.columns;
+        int nrhs = B.columns;
+        int minmn = min(m, n);
+        int maxmn = max(m, n);
+
+        if (B.rows < maxmn) {
+            throw new SizeException("Result matrix B must be padded to contain the solution matrix X!");
+        }
+
+        int smlsiz = NativeBlas.ilaenv(9, "SGELSD", "", m, n, nrhs, 0);
+        // Same formula as SGELSD itself: real division, and (int) truncates toward zero like
+        // Fortran INT. Integer division here used to yield 0 instead of 1 whenever
+        // (smlsiz + 1) / 2 < minmn <= smlsiz, under-sizing iwork.
+        int nlvl = max(0, (int) log2((double) minmn / (smlsiz + 1)) + 1);
+
+        // LAPACK requires IWORK to have at least max(1, LIWORK) entries.
+        int[] iwork = new int[max(1, 3 * minmn * nlvl + 11 * minmn)];
+        float[] s = new float[minmn];
+        int[] rank = new int[1];
+        // LAPACK requires lda >= max(1, m) and ldb >= max(1, max(m, n)); guard empty inputs.
+        int info = NativeBlas.sgelsd(m, n, nrhs, A.data, 0, max(1, m), B.data, 0, max(1, B.rows),
+                s, 0, -1, rank, 0, iwork, 0);
+        if (info < 0) {
+            throw new LapackArgumentException("GELSD", -info);
+        } else if (info > 0) {
+            throw new LapackConvergenceException("GELSD", info + " off-diagonal elements of an intermediate bidiagonal form did not converge to 0.");
+        }
+    }
+
+ /**
+ * QR factorization of A via LAPACK GEQRF, computed in place.
+ *
+ * @param A (m,n) matrix; overwritten with R (upper triangle) and the Householder reflectors (below it)
+ * @param tau receives the scalar factors of the reflectors; needs at least min(m,n) entries
+ */
+ public static void geqrf(FloatMatrix A, FloatMatrix tau) {
+ int info = NativeBlas.sgeqrf(A.rows, A.columns, A.data, 0, A.rows, tau.data, 0);
+ // Non-zero info codes are turned into exceptions by checkInfo.
+ checkInfo("GEQRF", info);
+ }
+
+ /**
+ * Multiplies C by the orthogonal matrix Q from a previous geqrf (LAPACK ORMQR), in place.
+ *
+ * @param side 'L' to compute Q*C or Q^T*C, 'R' to compute C*Q or C*Q^T
+ * @param trans 'N' to apply Q, 'T' to apply Q^T
+ * @param A the reflectors as left in A by geqrf
+ * @param tau reflector scalars from geqrf; tau.length is passed as k, the number of reflectors
+ * @param C the matrix to multiply; overwritten with the result
+ */
+ public static void ormqr(char side, char trans, FloatMatrix A, FloatMatrix tau, FloatMatrix C) {
+ int k = tau.length;
+ int info = NativeBlas.sormqr(side, trans, C.rows, C.columns, k, A.data, 0, A.rows, tau.data, 0, C.data, 0, C.rows);
+ checkInfo("ORMQR", info);
+ }
+
//END
}
diff --git a/src/main/java/org/jblas/Singular.java b/src/main/java/org/jblas/Singular.java
index 424eb12..45d169f 100644
--- a/src/main/java/org/jblas/Singular.java
+++ b/src/main/java/org/jblas/Singular.java
@@ -4,6 +4,8 @@
*/
package org.jblas;
+import org.jblas.exceptions.LapackConvergenceException;
+
import static org.jblas.util.Functions.min;
/**
@@ -24,7 +26,11 @@ public class Singular {
DoubleMatrix S = new DoubleMatrix(min(m, n));
DoubleMatrix V = new DoubleMatrix(n, n);
- NativeBlas.dgesvd('A', 'A', m, n, A.dup().data, 0, m, S.data, 0, U.data, 0, m, V.data, 0, n);
+ int info = NativeBlas.dgesvd('A', 'A', m, n, A.dup().data, 0, m, S.data, 0, U.data, 0, m, V.data, 0, n);
+
+ if (info > 0) {
+ throw new LapackConvergenceException("GESVD", info + " superdiagonals of an intermediate bidiagonal form failed to converge.");
+ }
return new DoubleMatrix[]{U, S, V.transpose()};
}
@@ -32,7 +38,7 @@ public class Singular {
/**
* Compute a singular-value decomposition of A (sparse variant).
* Sparse means that the matrices U and V are not square but
- * only have as many columns (or rows) as possible.
+ * only have as many columns (or rows) as necessary.
*
* @param A
* @return A DoubleMatrix[3] array of U, S, V such that A = U * diag(S) * V'
@@ -45,11 +51,23 @@ public class Singular {
DoubleMatrix S = new DoubleMatrix(min(m, n));
DoubleMatrix V = new DoubleMatrix(min(m, n), n);
- NativeBlas.dgesvd('S', 'S', m, n, A.dup().data, 0, m, S.data, 0, U.data, 0, m, V.data, 0, min(m, n));
+ int info = NativeBlas.dgesvd('S', 'S', m, n, A.dup().data, 0, m, S.data, 0, U.data, 0, m, V.data, 0, min(m, n));
+
+ if (info > 0) {
+ throw new LapackConvergenceException("GESVD", info + " superdiagonals of an intermediate bidiagonal form failed to converge.");
+ }
return new DoubleMatrix[]{U, S, V.transpose()};
}
+ /**
+ * Compute a singular-value decomposition of A (sparse variant).
+ * Sparse means that the matrices U and V are not square but only have
+ * as many columns (or rows) as necessary.
+ *
+ * @param A
+ * @return A ComplexDoubleMatrix[3] array of U, S, V such that A = U * diag(S) * V*
+ */
public static ComplexDoubleMatrix[] sparseSVD(ComplexDoubleMatrix A) {
int m = A.rows;
int n = A.columns;
@@ -60,9 +78,37 @@ public class Singular {
double[] rwork = new double[5*min(m,n)];
- NativeBlas.zgesvd('S', 'S', m, n, A.dup().data, 0, m, S.data, 0, U.data, 0, m, V.data, 0, min(m, n), rwork, 0);
+ int info = NativeBlas.zgesvd('S', 'S', m, n, A.dup().data, 0, m, S.data, 0, U.data, 0, m, V.data, 0, min(m, n), rwork, 0);
- return new ComplexDoubleMatrix[]{U, new ComplexDoubleMatrix(S), V.transpose()};
+ if (info > 0) {
+ throw new LapackConvergenceException("GESVD", info + " superdiagonals of an intermediate bidiagonal form failed to converge.");
+ }
+
+ return new ComplexDoubleMatrix[]{U, new ComplexDoubleMatrix(S), V.hermitian()};
+ }
+
+ /**
+ * Compute a full singular-value decomposition of A.
+ * Full means that U is (m,m) and V is (n,n); diag(S) is then read as the (m,n)
+ * rectangular diagonal matrix holding the min(m,n) singular values in S.
+ *
+ * @param A the (m,n) matrix to decompose; not modified (LAPACK works on a copy)
+ * @return A ComplexDoubleMatrix[3] array of U, S, V such that A = U * diag(S) * V*,
+ * where V* denotes the conjugate transpose of V
+ * @throws LapackConvergenceException if GESVD reports that it did not converge (info > 0)
+ */
+ public static ComplexDoubleMatrix[] fullSVD(ComplexDoubleMatrix A) {
+ int m = A.rows;
+ int n = A.columns;
+
+ ComplexDoubleMatrix U = new ComplexDoubleMatrix(m, m);
+ DoubleMatrix S = new DoubleMatrix(min(m, n));
+ ComplexDoubleMatrix V = new ComplexDoubleMatrix(n, n);
+
+ double[] rwork = new double[5*min(m,n)];
+
+ int info = NativeBlas.zgesvd('A', 'A', m, n, A.dup().data, 0, m, S.data, 0, U.data, 0, m, V.data, 0, n, rwork, 0);
+
+ if (info > 0) {
+ throw new LapackConvergenceException("GESVD", info + " superdiagonals of an intermediate bidiagonal form failed to converge.");
+ }
+
+ // LAPACK stores V* (VT) in V; take the conjugate transpose to return V itself.
+ return new ComplexDoubleMatrix[]{U, new ComplexDoubleMatrix(S), V.hermitian()};
}
/**
@@ -76,7 +122,11 @@ public class Singular {
int n = A.columns;
DoubleMatrix S = new DoubleMatrix(min(m, n));
- NativeBlas.dgesvd('N', 'N', m, n, A.dup().data, 0, m, S.data, 0, null, 0, 1, null, 0, 1);
+ int info = NativeBlas.dgesvd('N', 'N', m, n, A.dup().data, 0, m, S.data, 0, null, 0, 1, null, 0, 1);
+
+ if (info > 0) {
+ throw new LapackConvergenceException("GESVD", info + " superdiagonals of an intermediate bidiagonal form failed to converge.");
+ }
return S;
}
@@ -93,7 +143,11 @@ public class Singular {
DoubleMatrix S = new DoubleMatrix(min(m, n));
double[] rwork = new double[5*min(m,n)];
- NativeBlas.zgesvd('N', 'N', m, n, A.dup().data, 0, m, S.data, 0, null, 0, 1, null, 0, min(m,n), rwork, 0);
+ int info = NativeBlas.zgesvd('N', 'N', m, n, A.dup().data, 0, m, S.data, 0, null, 0, 1, null, 0, min(m,n), rwork, 0);
+
+ if (info > 0) {
+ throw new LapackConvergenceException("GESVD", info + " superdiagonals of an intermediate bidiagonal form failed to converge.");
+ }
return S;
}
@@ -115,7 +169,11 @@ public class Singular {
FloatMatrix S = new FloatMatrix(min(m, n));
FloatMatrix V = new FloatMatrix(n, n);
- NativeBlas.sgesvd('A', 'A', m, n, A.dup().data, 0, m, S.data, 0, U.data, 0, m, V.data, 0, n);
+ int info = NativeBlas.sgesvd('A', 'A', m, n, A.dup().data, 0, m, S.data, 0, U.data, 0, m, V.data, 0, n);
+
+ if (info > 0) {
+ throw new LapackConvergenceException("GESVD", info + " superdiagonals of an intermediate bidiagonal form failed to converge.");
+ }
return new FloatMatrix[]{U, S, V.transpose()};
}
@@ -123,7 +181,7 @@ public class Singular {
/**
* Compute a singular-value decomposition of A (sparse variant).
* Sparse means that the matrices U and V are not square but
- * only have as many columns (or rows) as possible.
+ * only have as many columns (or rows) as necessary.
*
* @param A
* @return A FloatMatrix[3] array of U, S, V such that A = U * diag(S) * V'
@@ -136,11 +194,23 @@ public class Singular {
FloatMatrix S = new FloatMatrix(min(m, n));
FloatMatrix V = new FloatMatrix(min(m, n), n);
- NativeBlas.sgesvd('S', 'S', m, n, A.dup().data, 0, m, S.data, 0, U.data, 0, m, V.data, 0, min(m, n));
+ int info = NativeBlas.sgesvd('S', 'S', m, n, A.dup().data, 0, m, S.data, 0, U.data, 0, m, V.data, 0, min(m, n));
+
+ if (info > 0) {
+ throw new LapackConvergenceException("GESVD", info + " superdiagonals of an intermediate bidiagonal form failed to converge.");
+ }
return new FloatMatrix[]{U, S, V.transpose()};
}
+ /**
+ * Compute a singular-value decomposition of A (sparse variant).
+ * Sparse means that the matrices U and V are not square but only have
+ * as many columns (or rows) as necessary.
+ *
+ * @param A
+ * @return A ComplexFloatMatrix[3] array of U, S, V such that A = U * diag(S) * V*
+ */
public static ComplexFloatMatrix[] sparseSVD(ComplexFloatMatrix A) {
int m = A.rows;
int n = A.columns;
@@ -151,9 +221,37 @@ public class Singular {
float[] rwork = new float[5*min(m,n)];
- NativeBlas.cgesvd('S', 'S', m, n, A.dup().data, 0, m, S.data, 0, U.data, 0, m, V.data, 0, min(m, n), rwork, 0);
+ int info = NativeBlas.cgesvd('S', 'S', m, n, A.dup().data, 0, m, S.data, 0, U.data, 0, m, V.data, 0, min(m, n), rwork, 0);
+
+ if (info > 0) {
+ throw new LapackConvergenceException("GESVD", info + " superdiagonals of an intermediate bidiagonal form failed to converge.");
+ }
+
+ return new ComplexFloatMatrix[]{U, new ComplexFloatMatrix(S), V.hermitian()};
+ }
+
+ /**
+ * Compute a full singular-value decomposition of A.
+ * Full means that U is (m,m) and V is (n,n); diag(S) is then read as the (m,n)
+ * rectangular diagonal matrix holding the min(m,n) singular values in S.
+ *
+ * @param A the (m,n) matrix to decompose; not modified (LAPACK works on a copy)
+ * @return A ComplexFloatMatrix[3] array of U, S, V such that A = U * diag(S) * V*,
+ * where V* denotes the conjugate transpose of V
+ * @throws LapackConvergenceException if GESVD reports that it did not converge (info > 0)
+ */
+ public static ComplexFloatMatrix[] fullSVD(ComplexFloatMatrix A) {
+ int m = A.rows;
+ int n = A.columns;
+
+ ComplexFloatMatrix U = new ComplexFloatMatrix(m, m);
+ FloatMatrix S = new FloatMatrix(min(m, n));
+ ComplexFloatMatrix V = new ComplexFloatMatrix(n, n);
+
+ float[] rwork = new float[5*min(m,n)];
- return new ComplexFloatMatrix[]{U, new ComplexFloatMatrix(S), V.transpose()};
+ int info = NativeBlas.cgesvd('A', 'A', m, n, A.dup().data, 0, m, S.data, 0, U.data, 0, m, V.data, 0, n, rwork, 0);
+
+ if (info > 0) {
+ throw new LapackConvergenceException("GESVD", info + " superdiagonals of an intermediate bidiagonal form failed to converge.");
+ }
+
+ // LAPACK stores V* (VT) in V; take the conjugate transpose to return V itself.
+ return new ComplexFloatMatrix[]{U, new ComplexFloatMatrix(S), V.hermitian()};
}
/**
@@ -167,7 +265,11 @@ public class Singular {
int n = A.columns;
FloatMatrix S = new FloatMatrix(min(m, n));
- NativeBlas.sgesvd('N', 'N', m, n, A.dup().data, 0, m, S.data, 0, null, 0, 1, null, 0, 1);
+ int info = NativeBlas.sgesvd('N', 'N', m, n, A.dup().data, 0, m, S.data, 0, null, 0, 1, null, 0, 1);
+
+ if (info > 0) {
+ throw new LapackConvergenceException("GESVD", info + " superdiagonals of an intermediate bidiagonal form failed to converge.");
+ }
return S;
}
@@ -184,7 +286,11 @@ public class Singular {
FloatMatrix S = new FloatMatrix(min(m, n));
float[] rwork = new float[5*min(m,n)];
- NativeBlas.cgesvd('N', 'N', m, n, A.dup().data, 0, m, S.data, 0, null, 0, 1, null, 0, min(m,n), rwork, 0);
+ int info = NativeBlas.cgesvd('N', 'N', m, n, A.dup().data, 0, m, S.data, 0, null, 0, 1, null, 0, min(m,n), rwork, 0);
+
+ if (info > 0) {
+ throw new LapackConvergenceException("GESVD", info + " superdiagonals of an intermediate bidiagonal form failed to converge.");
+ }
return S;
}
diff --git a/src/main/java/org/jblas/Solve.java b/src/main/java/org/jblas/Solve.java
index a6ec241..948ed71 100644
--- a/src/main/java/org/jblas/Solve.java
+++ b/src/main/java/org/jblas/Solve.java
@@ -36,6 +36,8 @@
package org.jblas;
+import org.jblas.util.Functions;
+
/**
* Solving linear equations.
*/
@@ -67,6 +69,44 @@ public class Solve {
return X;
}
+    /**
+     * Computes the least squares solution for over- or underdetermined
+     * linear equations A*X = B.
+     *
+     * In the overdetermined case, when m > n, that is, there are more equations than
+     * variables, it computes the least squares solution of X -> ||A*X - B ||_2.
+     *
+     * In the underdetermined case, when m < n (fewer equations than variables), there are infinitely
+     * many solutions and it computes the minimum norm solution.
+     *
+     * A and B are not overwritten; GELSD works on copies.
+     *
+     * @param A an (m,n) matrix
+     * @param B an (m,k) matrix
+     * @return either the minimum norm or least squares solution, as an (n,k) matrix.
+     * @throws org.jblas.exceptions.SizeException if B has fewer rows than A
+     */
+    public static DoubleMatrix solveLeastSquares(DoubleMatrix A, DoubleMatrix B) {
+        if (B.rows < A.rows) {
+            // Otherwise B would be zero-padded below and GELSD could silently read the
+            // padding as part of the right-hand side, returning a wrong solution.
+            throw new org.jblas.exceptions.SizeException(
+                    "Matrix B must have at least as many rows as A (" + A.rows + "), but has " + B.rows + ".");
+        }
+        if (B.rows < A.columns) {
+            // GELSD overwrites B with the (n,k) solution, so B must be padded to n rows.
+            DoubleMatrix X = DoubleMatrix.concatVertically(B, new DoubleMatrix(A.columns - B.rows, B.columns));
+            SimpleBlas.gelsd(A.dup(), X);
+            return X;
+        } else {
+            DoubleMatrix X = B.dup();
+            SimpleBlas.gelsd(A.dup(), X);
+            // Only the first n rows of the overwritten right-hand side hold the solution.
+            return X.getRange(0, A.columns, 0, B.columns);
+        }
+    }
+
+ /**
+ * Computes the (Moore-Penrose) pseudo-inverse.
+ *
+ * Note, this function uses solveLeastSquares (solving A*P = I with an (m,m) identity) and
+ * might produce numerically different solutions for the underdetermined case than MATLAB.
+ *
+ * @param A rectangular (m,n) matrix; not modified
+ * @return (n,m) matrix P such that A*P*A = A and P*A*P = P.
+ */
+ public static DoubleMatrix pinv(DoubleMatrix A) {
+ return solveLeastSquares(A, DoubleMatrix.eye(A.rows));
+ }
+
//BEGIN
// The code below has been automatically generated.
// DO NOT EDIT!
@@ -97,5 +137,43 @@ public class Solve {
return X;
}
+    /**
+     * Computes the least squares solution for over- or underdetermined
+     * linear equations A*X = B.
+     *
+     * In the overdetermined case, when m > n, that is, there are more equations than
+     * variables, it computes the least squares solution of X -> ||A*X - B ||_2.
+     *
+     * In the underdetermined case, when m < n (fewer equations than variables), there are infinitely
+     * many solutions and it computes the minimum norm solution.
+     *
+     * A and B are not overwritten; GELSD works on copies.
+     *
+     * @param A an (m,n) matrix
+     * @param B an (m,k) matrix
+     * @return either the minimum norm or least squares solution, as an (n,k) matrix.
+     * @throws org.jblas.exceptions.SizeException if B has fewer rows than A
+     */
+    public static FloatMatrix solveLeastSquares(FloatMatrix A, FloatMatrix B) {
+        if (B.rows < A.rows) {
+            // Otherwise B would be zero-padded below and GELSD could silently read the
+            // padding as part of the right-hand side, returning a wrong solution.
+            throw new org.jblas.exceptions.SizeException(
+                    "Matrix B must have at least as many rows as A (" + A.rows + "), but has " + B.rows + ".");
+        }
+        if (B.rows < A.columns) {
+            // GELSD overwrites B with the (n,k) solution, so B must be padded to n rows.
+            FloatMatrix X = FloatMatrix.concatVertically(B, new FloatMatrix(A.columns - B.rows, B.columns));
+            SimpleBlas.gelsd(A.dup(), X);
+            return X;
+        } else {
+            FloatMatrix X = B.dup();
+            SimpleBlas.gelsd(A.dup(), X);
+            // Only the first n rows of the overwritten right-hand side hold the solution.
+            return X.getRange(0, A.columns, 0, B.columns);
+        }
+    }
+
+ /**
+ * Computes the (Moore-Penrose) pseudo-inverse.
+ *
+ * Note, this function uses solveLeastSquares (solving A*P = I with an (m,m) identity) and
+ * might produce numerically different solutions for the underdetermined case than MATLAB.
+ *
+ * @param A rectangular (m,n) matrix; not modified
+ * @return (n,m) matrix P such that A*P*A = A and P*A*P = P.
+ */
+ public static FloatMatrix pinv(FloatMatrix A) {
+ return solveLeastSquares(A, FloatMatrix.eye(A.rows));
+ }
+
//END
}
diff --git a/src/main/java/org/jblas/benchmark/Main.java b/src/main/java/org/jblas/benchmark/Main.java
index a3e98c7..1a5a32f 100644
--- a/src/main/java/org/jblas/benchmark/Main.java
+++ b/src/main/java/org/jblas/benchmark/Main.java
@@ -62,7 +62,8 @@ public class Main {
+ " --arch-flavor=value overriding arch flavor (e.g. --arch-flavor=sse2)%n"
+ " --skip-java don't run java benchmarks%n"
+ " --help show this help%n"
- + " --debug set config levels to debug%n");
+ + " --debug set config levels to debug%n"
+ + "%njblas version " + org.jblas.Info.VERSION + "%n");
}
public static void main(String[] args) {
@@ -72,6 +73,10 @@ public class Main {
boolean skipJava = false;
boolean unrecognizedOptions = false;
+ Logger log = Logger.getLogger();
+
+ log.info("jblas version is " + org.jblas.Info.VERSION);
+
for (String arg : args) {
if (arg.startsWith("--")) {
int i = arg.indexOf('=');
diff --git a/src/main/java/org/jblas/exceptions/UnsupportedArchitectureException.java b/src/main/java/org/jblas/exceptions/UnsupportedArchitectureException.java
new file mode 100644
index 0000000..6fa3520
--- /dev/null
+++ b/src/main/java/org/jblas/exceptions/UnsupportedArchitectureException.java
@@ -0,0 +1,16 @@
+package org.jblas.exceptions;
+
+/**
+ * Unchecked exception for an operating system / architecture that jblas cannot handle.
+ * <p/>
+ * NativeBlasLibraryLoader catches it while loading the native libraries and prints its
+ * message; presumably it is thrown by LibraryLoader when no native build matches
+ * os.name/os.arch (TODO confirm).
+ * <p/>
+ * User: mikio
+ * Date: 2/13/13
+ * Time: 12:28 PM
+ */
+public class UnsupportedArchitectureException extends RuntimeException {
+ public UnsupportedArchitectureException(String message) {
+ super(message);
+ }
+}
diff --git a/src/main/java/org/jblas/ranges/AllRange.java b/src/main/java/org/jblas/ranges/AllRange.java
index cc1b804..81c0e69 100644
--- a/src/main/java/org/jblas/ranges/AllRange.java
+++ b/src/main/java/org/jblas/ranges/AllRange.java
@@ -48,38 +48,44 @@ import org.jblas.*;
* the ":" index in matlab. Don't forget to call init() before using this range.
*/
public class AllRange implements Range {
- private int lower;
- private int upper;
- private int value;
- private int counter;
-
- public AllRange() {}
-
- public void init(int l, int u) {
- lower = l;
- upper = u;
- value = l;
- counter = 0;
- }
-
- public int length() {
- return upper - lower;
- }
-
- public int value() {
- return value;
- }
-
- public int index() {
- return counter;
- }
-
- public void next() {
- counter++;
- value++;
- }
-
- public boolean hasMore() {
- return value < upper;
- }
+ private int lower;
+ private int upper;
+ private int value;
+ private int counter;
+
+ public AllRange() {
+ }
+
+ public void init(int l, int u) {
+ lower = l;
+ upper = u;
+ value = l;
+ counter = 0;
+ }
+
+ public int length() {
+ return upper - lower;
+ }
+
+ public int value() {
+ return value;
+ }
+
+ public int index() {
+ return counter;
+ }
+
+ public void next() {
+ counter++;
+ value++;
+ }
+
+ public boolean hasMore() {
+ return value < upper;
+ }
+
+ /** Debug description showing the bounds, length and current iteration state. */
+ @Override
+ public String toString() {
+ return String.format("<AllRange from %d to %d, with length %d, index=%d, value=%d>", lower, upper, length(), index(), value());
+ }
}
diff --git a/src/main/java/org/jblas/ranges/IntervalRange.java b/src/main/java/org/jblas/ranges/IntervalRange.java
index 65e5f06..763b2a9 100644
--- a/src/main/java/org/jblas/ranges/IntervalRange.java
+++ b/src/main/java/org/jblas/ranges/IntervalRange.java
@@ -42,44 +42,51 @@ package org.jblas.ranges;
/**
* Range which varies from a given interval. Endpoint is exclusive!
- *
+ * <p/>
* "new IntervalRange(0, 3)" enumerates 0, 1, 2.
*/
public class IntervalRange implements Range {
- private int start;
- private int end;
- private int value;
+ private int start;
+ private int end;
+ private int value;
- /** Construct new interval range. Endpoints are inclusive. */
- public IntervalRange(int a, int b) {
- start = a;
- end = b;
- }
+ /**
+ * Construct new interval range from a (inclusive) to b (exclusive).
+ */
+ public IntervalRange(int a, int b) {
+ start = a;
+ end = b;
+ }
- public void init(int lower, int upper) {
- value = start;
- if (start < lower || end > upper + 1) {
- throw new IllegalArgumentException("Bounds " + lower + " to " + upper + " are beyond range interval " + start + " to " + end + ".");
- }
+ public void init(int lower, int upper) {
+ value = start;
+ // NOTE(review): AllRange.init treats `upper` as exclusive (length() == upper - lower).
+ // If callers pass an exclusive upper here as well, `end > upper + 1` lets this interval
+ // run one index past the valid range — confirm against the callers of Range.init.
+ if (start < lower || end > upper + 1) {
+ throw new IllegalArgumentException("Bounds " + lower + " to " + upper + " are beyond range interval " + start + " to " + end + ".");
}
+ }
- public int length() {
- return end - start;
- }
+ public int length() {
+ return end - start;
+ }
- public void next() {
- value++;
- }
-
- public int index() {
- return value;
- }
-
- public int value() {
- return value;
- }
+ public void next() {
+ value++;
+ }
- public boolean hasMore() {
- return value < end;
- }
+ public int index() {
+ return value - start;
+ }
+
+ public int value() {
+ return value;
+ }
+
+ public boolean hasMore() {
+ return value < end;
+ }
+
+ /** Debug description showing the interval, length and current iteration state. */
+ @Override
+ public String toString() {
+ return String.format("<Interval Range from %d to %d, length %d index=%d value=%d>", start, end, length(), index(), value());
+ }
}
diff --git a/src/main/java/org/jblas/ranges/PointRange.java b/src/main/java/org/jblas/ranges/PointRange.java
index b6d546a..b29fbd0 100644
--- a/src/main/java/org/jblas/ranges/PointRange.java
+++ b/src/main/java/org/jblas/ranges/PointRange.java
@@ -45,35 +45,42 @@ package org.jblas.ranges;
* A PointRange is a range which only has a single point.
*/
public class PointRange implements Range {
- private int value;
- private boolean consumed;
+ private int value;
+ private boolean consumed;
- /** Construct a new PointRange with the one given index. */
- public PointRange(int v) {
- value = v;
- }
-
- public void init(int l, int u) {
- consumed = false;
- }
+ /**
+ * Construct a new PointRange with the one given index.
+ */
+ public PointRange(int v) {
+ value = v;
+ }
- public int length() {
- return 1;
- }
-
- public int value() {
- return value;
- }
-
- public int index() {
- return 0;
- }
-
- public void next() {
- consumed = true;
- }
-
- public boolean hasMore() {
- return !consumed;
- }
+ public void init(int l, int u) {
+ consumed = false;
+ }
+
+ public int length() {
+ return 1;
+ }
+
+ public int value() {
+ return value;
+ }
+
+ public int index() {
+ return 0;
+ }
+
+ public void next() {
+ consumed = true;
+ }
+
+ public boolean hasMore() {
+ return !consumed;
+ }
+
+ /** Debug description showing the single index of this range. */
+ @Override
+ public String toString() {
+ return String.format("<PointRange at=%d>", value);
+ }
}
diff --git a/src/main/java/org/jblas/util/ArchFlavor.java b/src/main/java/org/jblas/util/ArchFlavor.java
index 1438ff5..91e06e0 100644
--- a/src/main/java/org/jblas/util/ArchFlavor.java
+++ b/src/main/java/org/jblas/util/ArchFlavor.java
@@ -41,13 +41,13 @@ package org.jblas.util;
public class ArchFlavor {
static {
- try {
+ try {
System.loadLibrary("jblas_arch_flavor");
} catch (UnsatisfiedLinkError e) {
Logger.getLogger().config("ArchFlavor native library not found in path. Copying native library "
+ "libjblas_arch_flavor from the archive. Consider installing the library somewhere "
+ "in the path (for Windows: PATH, for Linux: LD_LIBRARY_PATH).");
- new org.jblas.util.LibraryLoader().loadLibrary("jblas_arch_flavor", false);
+ new org.jblas.util.LibraryLoader().loadLibrary("jblas_arch_flavor", false, false);
}
}
private static String fixedFlavor = null;
@@ -66,6 +66,7 @@ public class ArchFlavor {
String arch = System.getProperty("os.arch");
String name = System.getProperty("os.name");
+
if (name.startsWith("Windows") && arch.equals("amd64")) {
return null;
}
diff --git a/src/main/java/org/jblas/util/Functions.java b/src/main/java/org/jblas/util/Functions.java
index d2d9440..17449ef 100644
--- a/src/main/java/org/jblas/util/Functions.java
+++ b/src/main/java/org/jblas/util/Functions.java
@@ -46,4 +46,10 @@ public class Functions {
public static int min(int a, int b) { return a < b ? a : b; }
public static int max(int a, int b) { return a > b ? a : b; }
+
+ private static final double LOG2 = 0.6931471805599453;
+
+ public static double log2(double x) {
+ return Math.log(x) / LOG2;
+ }
}
diff --git a/src/main/java/org/jblas/util/LibraryLoader.java b/src/main/java/org/jblas/util/LibraryLoader.java
index b5862f4..481603c 100644
--- a/src/main/java/org/jblas/util/LibraryLoader.java
+++ b/src/main/java/org/jblas/util/LibraryLoader.java
@@ -1,181 +1,224 @@
-// --- BEGIN LICENSE BLOCK ---
-/*
- * Copyright (c) 2009, Mikio L. Braun
- * All rights reserved.
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are
- * met:
- *
- * * Redistributions of source code must retain the above copyright
- * notice, this list of conditions and the following disclaimer.
- *
- * * Redistributions in binary form must reproduce the above
- * copyright notice, this list of conditions and the following
- * disclaimer in the documentation and/or other materials provided
- * with the distribution.
- *
- * * Neither the name of the Technische Universität Berlin nor the
- * names of its contributors may be used to endorse or promote
- * products derived from this software without specific prior
- * written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
- * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
- * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
- * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
- * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
- * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
- * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
- * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
- * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- */
-// --- END LICENSE BLOCK ---
-package org.jblas.util;
-
-import java.io.*;
-
-/**
- * Class which allows to load a dynamic file as resource (for example, from a
- * jar-file)
- */
-public class LibraryLoader {
-
- private Logger logger;
- private String libpath;
-
- public LibraryLoader() {
- logger = Logger.getLogger();
- libpath = null;
- }
-
- /**
- * <p>Find the library <tt>libname</tt> as a resource, copy it to a tempfile
- * and load it using System.load(). The name of the library has to be the
- * base name, it is mapped to the corresponding system name using
- * System.mapLibraryName(). For example, the library "foo" is called "libfoo.so"
- * under Linux and "foo.dll" under Windows, but you just have to pass "foo"
- * the loadLibrary().</p>
- *
- * <p>I'm not quite sure if this doesn't open all kinds of security holes. Any ideas?</p>
- *
- * <p>This function reports some more information to the "org.jblas" logger at
- * the FINE level.</p>
- *
- * @param libname basename of the library
- * @throws UnsatisfiedLinkError if library cannot be founds
- */
- public void loadLibrary(String libname, boolean withFlavor) {
- // preload flavor libraries
- String flavor = null;
- if (withFlavor) {
- logger.debug("Preloading ArchFlavor library.");
- flavor = ArchFlavor.archFlavor();
- }
-
- libname = System.mapLibraryName(libname);
- logger.debug("Attempting to load \"" + libname + "\".");
-
- String[] paths = {
- "/",
- "/bin/",
- fatJarLibraryPath("static", flavor),
- fatJarLibraryPathNonUnified("static", flavor),
- fatJarLibraryPath("dynamic", flavor),
- fatJarLibraryPathNonUnified("dynamic", flavor),
- };
-
- InputStream is = findLibrary(paths, libname);
-
- // Oh man, have to get out of here!
- if (is == null) {
- throw new UnsatisfiedLinkError("Couldn't find the resource " + libname + ".");
- }
-
- logger.config("Loading " + libname + " from " + libpath + ".");
- loadLibraryFromStream(libname, is);
- }
-
- private InputStream findLibrary(String[] paths, String libname) {
- InputStream is = null;
- for (String path: paths) {
- is = tryPath(path + libname);
- if (is != null) {
- libpath = path;
- break;
- }
- }
- return is;
- }
-
- /** Translate all those Windows to "Windows". ("Windows XP", "Windows Vista", "Windows 7", etc.) */
- private String unifyOSName(String osname) {
- if (osname.startsWith("Windows")) {
- return "Windows";
- }
- return osname;
- }
-
- /** Compute the path to the library. The path is basically
- "/" + os.name + "/" + os.arch + "/" + libname. */
- private String fatJarLibraryPath(String linkage, String flavor) {
- String sep = "/"; //System.getProperty("file.separator");
- String os_name = unifyOSName(System.getProperty("os.name"));
- String os_arch = System.getProperty("os.arch");
- String path = sep + "lib" + sep + linkage + sep + os_name + sep + os_arch + sep;
- if (null != flavor)
- path += flavor + sep;
- return path;
- }
-
- /** Full path without the OS name non-unified. */
- private String fatJarLibraryPathNonUnified(String linkage, String flavor) {
- String sep = "/"; //System.getProperty("file.separator");
- String os_name = System.getProperty("os.name");
- String os_arch = System.getProperty("os.arch");
- String path = sep + "lib" + sep + linkage + sep + os_name + sep + os_arch + sep;
- if (null != flavor)
- path += flavor + sep;
- return path;
- }
-
- /** Try to open a file at the given position. */
- private InputStream tryPath(String path) {
- Logger.getLogger().debug("Trying path \"" + path + "\".");
- return getClass().getResourceAsStream(path);
- }
-
- /** Load a system library from a stream. Copies the library to a temp file
- * and loads from there.
- */
- private void loadLibraryFromStream(String libname, InputStream is) {
- try {
- File tempfile = File.createTempFile("jblas", libname);
- tempfile.deleteOnExit();
- OutputStream os = new FileOutputStream(tempfile);
-
- logger.debug("tempfile.getPath() = " + tempfile.getPath());
-
- long savedTime = System.currentTimeMillis();
-
- byte buf[] = new byte[1024];
- int len;
- while ((len = is.read(buf)) > 0) {
- os.write(buf, 0, len);
- }
-
- double seconds = (double) (System.currentTimeMillis() - savedTime) / 1e3;
- logger.debug("Copying took " + seconds + " seconds.");
-
- os.close();
-
- System.load(tempfile.getPath());
- } catch (IOException io) {
- logger.error("Could not create the temp file: " + io.toString() + ".\n");
- } catch (UnsatisfiedLinkError ule) {
- logger.error("Couldn't load copied link file: " + ule.toString() + ".\n");
- }
- }
-}
+// --- BEGIN LICENSE BLOCK ---
+/*
+ * Copyright (c) 2009, Mikio L. Braun
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials provided
+ * with the distribution.
+ *
+ * * Neither the name of the Technische Universität Berlin nor the
+ * names of its contributors may be used to endorse or promote
+ * products derived from this software without specific prior
+ * written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+// --- END LICENSE BLOCK ---
+package org.jblas.util;
+
+import org.jblas.exceptions.UnsupportedArchitectureException;
+
+import java.io.*;
+
+/**
+ * Class which allows loading a dynamic library as a resource (for example, from a
+ * jar-file).
+ */
+public class LibraryLoader {
+
+ private Logger logger;
+ private String libpath;
+
+ public LibraryLoader() {
+ logger = Logger.getLogger();
+ libpath = null;
+ }
+
+ /**
+ * <p>Find the library <tt>libname</tt> as a resource, copy it to a tempfile
+ * and load it using System.load(). The name of the library has to be the
+ * base name, it is mapped to the corresponding system name using
+ * System.mapLibraryName(). For example, the library "foo" is called "libfoo.so"
+ * under Linux and "foo.dll" under Windows, but you just have to pass "foo"
+ * to loadLibrary().</p>
+ * <p/>
+ * <p>I'm not quite sure if this doesn't open all kinds of security holes. Any ideas?</p>
+ * <p/>
+ * <p>This function reports some more information to the "org.jblas" logger at
+ * the FINE level.</p>
+ *
+ * @param libname basename of the library
+ * @throws UnsatisfiedLinkError if library cannot be found
+ */
+ public void loadLibrary(String libname, boolean withFlavor, boolean noPrefix) {
+ // preload flavor libraries
+ String flavor = null;
+ if (withFlavor) {
+ logger.debug("Preloading ArchFlavor library.");
+ flavor = ArchFlavor.archFlavor();
+ if (flavor != null && flavor.equals("sse2")) {
+ throw new UnsupportedArchitectureException("Support for SSE2 processors stopped with version 1.2.2. Sorry.");
+ }
+ }
+ logger.debug("Found flavor = '" + flavor + "'");
+
+ libname = System.mapLibraryName(libname);
+
+ /*
+ * JDK 7 changed the ending for Mac OS from "jnilib" to "dylib".
+ *
+ * If that is the case, remap the filename.
+ */
+ String loadLibname = libname;
+ if (libname.endsWith("dylib")) {
+ loadLibname = libname.replace(".dylib", ".jnilib");
+ logger.config("Replaced .dylib with .jnilib");
+ }
+
+ logger.debug("Attempting to load \"" + loadLibname + "\".");
+
+ String[] paths = {
+ "/",
+ "/bin/",
+ fatJarLibraryPath("static", flavor),
+ fatJarLibraryPathNonUnified("static", flavor),
+ fatJarLibraryPath("dynamic", flavor),
+ fatJarLibraryPathNonUnified("dynamic", flavor),
+ };
+
+ InputStream is = findLibrary(paths, loadLibname);
+
+ // Oh man, have to get out of here!
+ if (is == null) {
+ throw new UnsatisfiedLinkError("Couldn't find the resource " + loadLibname + ".");
+ }
+
+ logger.config("Loading " + loadLibname + " from " + libpath + ", copying to " + libname + ".");
+ loadLibraryFromStream(libname, is, noPrefix);
+ }
+
+ private InputStream findLibrary(String[] paths, String libname) {
+ InputStream is = null;
+ for (String path : paths) {
+ is = tryPath(path + libname);
+ if (is != null) {
+ logger.debug("Found " + libname + " in " + path);
+ libpath = path;
+ break;
+ }
+ }
+ return is;
+ }
+
+ /**
+ * Translate all Windows variants to "Windows" ("Windows XP", "Windows Vista", "Windows 7", etc.).
+ */
+ private String unifyOSName(String osname) {
+ if (osname.startsWith("Windows")) {
+ return "Windows";
+ }
+ return osname;
+ }
+
+ /**
+ * Compute the resource directory for the library, i.e.
+ * "/lib/" + linkage + "/" + unified os.name + "/" + os.arch + "/" (+ flavor + "/" if given).
+ */
+ private String fatJarLibraryPath(String linkage, String flavor) {
+ String sep = "/"; //System.getProperty("file.separator");
+ String os_name = unifyOSName(System.getProperty("os.name"));
+ String os_arch = System.getProperty("os.arch");
+ String path = sep + "lib" + sep + linkage + sep + os_name + sep + os_arch + sep;
+ if (null != flavor)
+ path += flavor + sep;
+ return path;
+ }
+
+ /**
+ * Like fatJarLibraryPath(), but uses the raw (non-unified) os.name.
+ */
+ private String fatJarLibraryPathNonUnified(String linkage, String flavor) {
+ String sep = "/"; //System.getProperty("file.separator");
+ String os_name = System.getProperty("os.name");
+ String os_arch = System.getProperty("os.arch");
+ String path = sep + "lib" + sep + linkage + sep + os_name + sep + os_arch + sep;
+ if (null != flavor)
+ path += flavor + sep;
+ return path;
+ }
+
+ /**
+ * Try to open a file at the given position.
+ */
+ private InputStream tryPath(String path) {
+ Logger.getLogger().debug("Trying path \"" + path + "\".");
+ return getClass().getResourceAsStream(path);
+ }
+
+ private File createTempFile(String prefix, String suffix, boolean noPrefix) throws IOException {
+ File tempfile = File.createTempFile(prefix, suffix);
+ if (noPrefix == true) {
+ return new File(tempfile.getParentFile(), suffix);
+ } else {
+ return tempfile;
+ }
+ }
+
+ /**
+ * Load a system library from a stream. Copies the library to a temp file
+ * and loads from there.
+ *
+ * @param libname name of the library (only used to name the temp file)
+ * @param is InputStream pointing to the library
+ */
+ private void loadLibraryFromStream(String libname, InputStream is, boolean noPrefix) {
+ try {
+ File tempfile = createTempFile("jblas", libname, noPrefix);
+ tempfile.deleteOnExit();
+ OutputStream os = new FileOutputStream(tempfile);
+
+ logger.debug("tempfile.getPath() = " + tempfile.getPath());
+
+ long savedTime = System.currentTimeMillis();
+
+ // Leo says 8k block size is STANDARD ;)
+ byte buf[] = new byte[8192];
+ int len;
+ while ((len = is.read(buf)) > 0) {
+ os.write(buf, 0, len);
+ }
+
+ double seconds = (double) (System.currentTimeMillis() - savedTime) / 1e3;
+ logger.debug("Copying took " + seconds + " seconds.");
+
+ os.close();
+
+ logger.debug("Loading library from " + tempfile.getPath() + ".");
+ System.load(tempfile.getPath());
+ } catch (IOException io) {
+ logger.error("Could not create the temp file: " + io.toString() + ".\n");
+ } catch (UnsatisfiedLinkError ule) {
+ logger.error("Couldn't load copied link file: " + ule.toString() + ".\n");
+ throw ule;
+ }
+ }
+}
diff --git a/src/main/java/org/jblas/util/Permutations.java b/src/main/java/org/jblas/util/Permutations.java
index 78373e2..2f6ab76 100644
--- a/src/main/java/org/jblas/util/Permutations.java
+++ b/src/main/java/org/jblas/util/Permutations.java
@@ -38,6 +38,7 @@ package org.jblas.util;
import java.util.Random;
import org.jblas.DoubleMatrix;
+import org.jblas.FloatMatrix;
/**
* Functions which generate random permutations.
@@ -95,7 +96,7 @@ public class Permutations {
*
* @param ipiv row i was interchanged with row ipiv[i]
*/
- public static DoubleMatrix permutationMatrixFromPivotIndices(int size, int[] ipiv) {
+ public static DoubleMatrix permutationDoubleMatrixFromPivotIndices(int size, int[] ipiv) {
int n = ipiv.length;
//System.out.printf("size = %d n = %d\n", size, n);
int indices[] = new int[size];
@@ -116,4 +117,31 @@ public class Permutations {
result.put(indices[i], i, 1.0);
return result;
}
+
+ /**
+ * Create a permutation matrix from a LAPACK-style 'ipiv' vector.
+ *
+ * @param ipiv row i was interchanged with row ipiv[i]
+ */
+ public static FloatMatrix permutationFloatMatrixFromPivotIndices(int size, int[] ipiv) {
+ int n = ipiv.length;
+ //System.out.printf("size = %d n = %d\n", size, n);
+ int indices[] = new int[size];
+ for (int i = 0; i < size; i++)
+ indices[i] = i;
+
+ //for (int i = 0; i < n; i++)
+ // System.out.printf("ipiv[%d] = %d\n", i, ipiv[i]);
+
+ for (int i = 0; i < n; i++) {
+ int j = ipiv[i] - 1;
+ int t = indices[i];
+ indices[i] = indices[j];
+ indices[j] = t;
+ }
+ FloatMatrix result = new FloatMatrix(size, size);
+ for (int i = 0; i < size; i++)
+ result.put(indices[i], i, 1.0f);
+ return result;
+ }
}
diff --git a/src/main/java/org/jblas/util/Random.java b/src/main/java/org/jblas/util/Random.java
new file mode 100644
index 0000000..bb1d1d2
--- /dev/null
+++ b/src/main/java/org/jblas/util/Random.java
@@ -0,0 +1,33 @@
+package org.jblas.util;
+
+/**
+ * Created by IntelliJ IDEA.
+ * User: mikio
+ * Date: 6/24/11
+ * Time: 10:45 AM
+ * Static facade over a shared java.util.Random; seed() replaces the shared instance.
+ */
+
+public class Random {
+ private static java.util.Random r = new java.util.Random();
+
+ public static void seed(long newSeed) {
+ r = new java.util.Random(newSeed);
+ }
+
+ public static double nextDouble() {
+ return r.nextDouble();
+ }
+
+ public static float nextFloat() {
+ return r.nextFloat();
+ }
+
+ public static int nextInt(int max) {
+ return r.nextInt(max);
+ }
+
+ public static double nextGaussian() {
+ return r.nextGaussian();
+ }
+}
diff --git a/src/main/resources/lib/static/Linux/amd64/libjblas_arch_flavor.so b/src/main/resources/lib/static/Linux/amd64/libjblas_arch_flavor.so
index 7a5d649..59bd9ab 100755
Binary files a/src/main/resources/lib/static/Linux/amd64/libjblas_arch_flavor.so and b/src/main/resources/lib/static/Linux/amd64/libjblas_arch_flavor.so differ
diff --git a/src/main/resources/lib/static/Linux/amd64/sse2/libjblas.so b/src/main/resources/lib/static/Linux/amd64/sse2/libjblas.so
deleted file mode 100755
index 569dfd5..0000000
Binary files a/src/main/resources/lib/static/Linux/amd64/sse2/libjblas.so and /dev/null differ
diff --git a/src/main/resources/lib/static/Linux/amd64/sse3/libjblas.so b/src/main/resources/lib/static/Linux/amd64/sse3/libjblas.so
index 3ddefec..d9e5d37 100755
Binary files a/src/main/resources/lib/static/Linux/amd64/sse3/libjblas.so and b/src/main/resources/lib/static/Linux/amd64/sse3/libjblas.so differ
diff --git a/src/main/resources/lib/static/Linux/i386/libjblas_arch_flavor.so b/src/main/resources/lib/static/Linux/i386/libjblas_arch_flavor.so
index 01180d6..9d61418 100755
Binary files a/src/main/resources/lib/static/Linux/i386/libjblas_arch_flavor.so and b/src/main/resources/lib/static/Linux/i386/libjblas_arch_flavor.so differ
diff --git a/src/main/resources/lib/static/Linux/i386/sse2/libjblas.so b/src/main/resources/lib/static/Linux/i386/sse2/libjblas.so
deleted file mode 100755
index 5425bd3..0000000
Binary files a/src/main/resources/lib/static/Linux/i386/sse2/libjblas.so and /dev/null differ
diff --git a/src/main/resources/lib/static/Linux/i386/sse3/libjblas.so b/src/main/resources/lib/static/Linux/i386/sse3/libjblas.so
index 4f988c6..57e8465 100755
Binary files a/src/main/resources/lib/static/Linux/i386/sse3/libjblas.so and b/src/main/resources/lib/static/Linux/i386/sse3/libjblas.so differ
diff --git a/src/main/resources/lib/static/Mac OS X/x86_64/libjblas_arch_flavor.jnilib b/src/main/resources/lib/static/Mac OS X/x86_64/libjblas_arch_flavor.jnilib
index 57caf20..43cc076 100755
Binary files a/src/main/resources/lib/static/Mac OS X/x86_64/libjblas_arch_flavor.jnilib and b/src/main/resources/lib/static/Mac OS X/x86_64/libjblas_arch_flavor.jnilib differ
diff --git a/src/main/resources/lib/static/Mac OS X/x86_64/sse3/libjblas.jnilib b/src/main/resources/lib/static/Mac OS X/x86_64/sse3/libjblas.jnilib
index ebab2a4..73a0d0a 100755
Binary files a/src/main/resources/lib/static/Mac OS X/x86_64/sse3/libjblas.jnilib and b/src/main/resources/lib/static/Mac OS X/x86_64/sse3/libjblas.jnilib differ
diff --git a/src/main/resources/lib/static/Windows/amd64/jblas.dll b/src/main/resources/lib/static/Windows/amd64/jblas.dll
index 62f8788..93f7b19 100755
Binary files a/src/main/resources/lib/static/Windows/amd64/jblas.dll and b/src/main/resources/lib/static/Windows/amd64/jblas.dll differ
diff --git a/src/main/resources/lib/static/Windows/amd64/jblas_arch_flavor.dll b/src/main/resources/lib/static/Windows/amd64/jblas_arch_flavor.dll
index 99742cf..8651979 100755
Binary files a/src/main/resources/lib/static/Windows/amd64/jblas_arch_flavor.dll and b/src/main/resources/lib/static/Windows/amd64/jblas_arch_flavor.dll differ
diff --git a/src/main/resources/lib/static/Windows/amd64/libgcc_s_sjlj-1.dll b/src/main/resources/lib/static/Windows/amd64/libgcc_s_sjlj-1.dll
new file mode 100644
index 0000000..3a69f1c
Binary files /dev/null and b/src/main/resources/lib/static/Windows/amd64/libgcc_s_sjlj-1.dll differ
diff --git a/src/main/resources/lib/static/Windows/amd64/libgfortran-3.dll b/src/main/resources/lib/static/Windows/amd64/libgfortran-3.dll
new file mode 100644
index 0000000..7d44e83
Binary files /dev/null and b/src/main/resources/lib/static/Windows/amd64/libgfortran-3.dll differ
diff --git a/src/main/resources/lib/static/Windows/x86/jblas_arch_flavor.dll b/src/main/resources/lib/static/Windows/x86/jblas_arch_flavor.dll
index 3844cd3..54ebb25 100755
Binary files a/src/main/resources/lib/static/Windows/x86/jblas_arch_flavor.dll and b/src/main/resources/lib/static/Windows/x86/jblas_arch_flavor.dll differ
diff --git a/src/main/resources/lib/static/Windows/x86/libgcc_s_dw2-1.dll b/src/main/resources/lib/static/Windows/x86/libgcc_s_dw2-1.dll
new file mode 100644
index 0000000..07da32a
Binary files /dev/null and b/src/main/resources/lib/static/Windows/x86/libgcc_s_dw2-1.dll differ
diff --git a/src/main/resources/lib/static/Windows/x86/libgfortran-3.dll b/src/main/resources/lib/static/Windows/x86/libgfortran-3.dll
new file mode 100644
index 0000000..418d4c6
Binary files /dev/null and b/src/main/resources/lib/static/Windows/x86/libgfortran-3.dll differ
diff --git a/src/main/resources/lib/static/Windows/x86/sse2/jblas.dll b/src/main/resources/lib/static/Windows/x86/sse2/jblas.dll
deleted file mode 100755
index 31f96eb..0000000
Binary files a/src/main/resources/lib/static/Windows/x86/sse2/jblas.dll and /dev/null differ
diff --git a/src/main/resources/lib/static/Windows/x86/sse3/jblas.dll b/src/main/resources/lib/static/Windows/x86/sse3/jblas.dll
index f851925..eaeb765 100755
Binary files a/src/main/resources/lib/static/Windows/x86/sse3/jblas.dll and b/src/main/resources/lib/static/Windows/x86/sse3/jblas.dll differ
diff --git a/src/test/java/org/jblas/ComplexDoubleMatrixTest.java b/src/test/java/org/jblas/ComplexDoubleMatrixTest.java
index ee32c1c..d940792 100644
--- a/src/test/java/org/jblas/ComplexDoubleMatrixTest.java
+++ b/src/test/java/org/jblas/ComplexDoubleMatrixTest.java
@@ -41,43 +41,30 @@
package org.jblas;
-import java.io.DataInputStream;
-import java.io.DataOutputStream;
-import junit.framework.TestCase;
+import org.junit.*;
+import static org.junit.Assert.*;
/**
*
* @author mikio
*/
-public class ComplexDoubleMatrixTest extends TestCase {
-
- public ComplexDoubleMatrixTest(String testName) {
- super(testName);
- }
-
- @Override
- protected void setUp() throws Exception {
- super.setUp();
- }
-
- @Override
- protected void tearDown() throws Exception {
- super.tearDown();
- }
+public class ComplexDoubleMatrixTest {
+
+ @Test
public void testConstruction() {
ComplexDoubleMatrix A = new ComplexDoubleMatrix(3, 3);
for (int i = 0; i < A.rows; i++)
for (int j = 0; j < A.columns; j++)
A.put(i, j, new ComplexDouble(i, j));
- System.out.printf("A = %s\n", A.toString());
+ //System.out.printf("A = %s\n", A.toString());
- System.out.println(A.mmul(A));
+ //System.out.println(A.mmul(A));
DoubleMatrix R = new DoubleMatrix(3, 3, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0);
A = new ComplexDoubleMatrix(R, R.transpose());
- System.out.println(A);
+ //System.out.println(A);
assertEquals(A.real(), R);
assertEquals(A.imag(), R.transpose());
diff --git a/src/main/java/org/jblas/FloatFunction.java b/src/test/java/org/jblas/JblasAssert.java
similarity index 80%
copy from src/main/java/org/jblas/FloatFunction.java
copy to src/test/java/org/jblas/JblasAssert.java
index 23cc618..82010a2 100644
--- a/src/main/java/org/jblas/FloatFunction.java
+++ b/src/test/java/org/jblas/JblasAssert.java
@@ -1,25 +1,25 @@
// --- BEGIN LICENSE BLOCK ---
-/*
- * Copyright (c) 2009, Mikio L. Braun
+/*
+ * Copyright (c) 2012, Mikio L. Braun
* All rights reserved.
- *
+ *
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
- *
+ *
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
- *
+ *
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following
* disclaimer in the documentation and/or other materials provided
* with the distribution.
- *
+ *
* * Neither the name of the Technische Universität Berlin nor the
* names of its contributors may be used to endorse or promote
* products derived from this software without specific prior
* written permission.
- *
+ *
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
@@ -36,10 +36,12 @@
package org.jblas;
-/**
- * Represents a function on floats.
- */
-public interface FloatFunction {
- /** Compute the function. */
- public float compute(float x);
+import static org.junit.Assert.assertArrayEquals;
+
+public class JblasAssert {
+ public static void assertEquals(DoubleMatrix expected, DoubleMatrix actual) {
+ org.junit.Assert.assertEquals(expected.rows, actual.rows);
+ org.junit.Assert.assertEquals(expected.columns, actual.columns);
+ assertArrayEquals(expected.data, actual.data, 1e-12);
+ }
}
diff --git a/src/test/java/org/jblas/SimpleBlasTest.java b/src/test/java/org/jblas/SimpleBlasTest.java
index b2fb358..341bb52 100644
--- a/src/test/java/org/jblas/SimpleBlasTest.java
+++ b/src/test/java/org/jblas/SimpleBlasTest.java
@@ -34,25 +34,20 @@
*/
// --- END LICENSE BLOCK ---
-/*
- * To change this template, choose Tools | Templates
- * and open the template in the editor.
- */
-
package org.jblas;
-import junit.framework.TestCase;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
/**
+ * Some tests for class SimpleBlas
*
- * @author mikio
+ * @author Mikio L. Braun
*/
-public class SimpleBlasTest extends TestCase {
+public class SimpleBlasTest {
- public SimpleBlasTest(String testName) {
- super(testName);
- }
-
+ @Test
public void testGeev() {
DoubleMatrix A = new DoubleMatrix(2, 2, 3.0, -3.0, 1.0, 1.0);
DoubleMatrix WR = new DoubleMatrix(2);
@@ -65,11 +60,11 @@ public class SimpleBlasTest extends TestCase {
assertEquals(new DoubleMatrix(2, 1, 2.0, 2.0), WR);
assertEquals(new DoubleMatrix(2, 1, Math.sqrt(2.0), -Math.sqrt(2.0)), WI);
- System.out.printf("WR = %s\n", WR.toString());
+ /*System.out.printf("WR = %s\n", WR.toString());
System.out.printf("WI = %s\n", WI.toString());
System.out.printf("VR = %s\n", VR.toString());
System.out.printf("VL = %s\n", VL.toString());
- System.out.printf("A = %s\n", A.toString());
+ System.out.printf("A = %s\n", A.toString());*/
}
}
diff --git a/src/test/java/org/jblas/TestBlasDouble.java b/src/test/java/org/jblas/TestBlasDouble.java
index 565a475..d831716 100644
--- a/src/test/java/org/jblas/TestBlasDouble.java
+++ b/src/test/java/org/jblas/TestBlasDouble.java
@@ -36,148 +36,166 @@
package org.jblas;
-import junit.framework.TestCase;
-
-import static org.jblas.MatrixFunctions.*;
-
-public class TestBlasDouble extends TestCase {
-
- /** test sum of absolute values */
- public void testAsum() {
- double[] a = new double[]{1.0, 2.0, 3.0, 4.0};
-
- assertEquals(10.0, NativeBlas.dasum(4, a, 0, 1));
- assertEquals(4.0, NativeBlas.dasum(2, a, 0, 2));
- assertEquals(5.0, NativeBlas.dasum(2, a, 1, 1));
- }
-
- /** test scalar product */
- public void testDot() {
- double[] a = new double[] { 1.0, 2.0, 3.0, 4.0 };
- double[] b = new double[] { 4.0, 5.0, 6.0, 7.0 };
-
- assertEquals(32.0, NativeBlas.ddot(3, a, 0, 1, b, 0, 1));
- assertEquals(22.0, NativeBlas.ddot(2, a, 0, 2, b, 0, 2));
- assertEquals(5.0 + 12.0 + 21.0, NativeBlas.ddot(3, a, 0, 1, b, 1, 1));
- }
-
- public void testSwap() {
- double[] a = new double[] { 1.0, 2.0, 3.0, 4.0 };
- double[] b = new double[] { 4.0, 5.0, 6.0, 7.0 };
- double[] c = new double[] { 1.0, 2.0, 3.0, 4.0 };
- double[] d = new double[] { 4.0, 5.0, 6.0, 7.0 };
-
- System.out.println("dswap");
- NativeBlas.dswap(4, a, 0, 1, b, 0, 1);
- assertTrue(arraysEqual(a, d));
- assertTrue(arraysEqual(b, c));
-
- System.out.println("dswap same");
- NativeBlas.dswap(2, a, 0, 2, a, 1, 2);
- assertTrue(arraysEqual(a, 5.0, 4.0, 7.0, 6.0));
- }
-
- /* test vector addition */
- public void testAxpy() {
- double[] x = new double[] { 1.0, 2.0, 3.0, 4.0 };
- double[] y = new double[] { 0.0, 0.0, 0.0, 0.0 };
-
- NativeBlas.daxpy(4, 2.0, x, 0, 1, y, 0, 1);
-
- for(int i = 0; i < 4; i++)
- assertEquals(2*x[i], y[i]);
- }
-
- /* test matric-vector multiplication */
- public void testGemv() {
- double[] A = new double[] { 1.0, 2.0, 3.0,
- 4.0, 5.0, 6.0,
- 7.0, 8.0, 9.0 };
-
- double[] x = new double[] {1.0, 3.0, 7.0 };
- double[] y = new double[] { 0.0, 0.0, 0.0 };
-
- NativeBlas.dgemv('N', 3, 3, 1.0, A, 0, 3, x, 0, 1, 0.0, y, 0, 1);
-
- //printMatrix(3, 3, A);
- //printMatrix(3, 1, x);
- //printMatrix(3, 1, y);
-
- assertTrue(arraysEqual(y, 62.0, 73.0, 84.0));
-
- NativeBlas.dgemv('T', 3, 3, 1.0, A, 0, 3, x, 0, 1, 0.5, y, 0, 1);
-
- //printMatrix(3, 1, y);
- assertTrue(arraysEqual(y, 59.0, 97.5, 136.0));
- }
-
- /** Compare double buffer against an array of doubles */
- private boolean arraysEqual(double[] a, double... b) {
- if (a.length != b.length)
- return false;
- else {
- double diff = 0.0;
- for (int i = 0; i < b.length; i++)
- diff += abs(a[i] - b[i]);
- return diff < 1e-6;
- }
- }
-
- public static void main(String[] args) {
- TestBlasDouble t = new TestBlasDouble();
-
- t.testAsum();
- }
-
- public static void testSolve() {
- DoubleMatrix A = new DoubleMatrix(3, 3, 3.0, 5.0, 6.0, 1.0, 0.0, 0.0, 2.0, 4.0, 0.0);
- DoubleMatrix X = new DoubleMatrix(3, 1, 1.0, 2.0, 3.0);
- int[] p = new int[3];
- SimpleBlas.gesv(A, p, X);
- A.print();
- X.print();
- // De-shuffle X
- for (int i = 2; i >= 0; i--) {
- int perm = p[i] - 1;
- double t = X.get(i); X.put(i, X.get(perm)); X.put(perm, t);
- }
- System.out.println();
- X.print();
- }
-
- public static void testSymmetricSolve() {
- System.out.println("--- Symmetric solve");
- DoubleMatrix A = new DoubleMatrix(3, 3, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0);
- DoubleMatrix x = new DoubleMatrix(3, 1, 1.0, 2.0, 3.0);
- int[] p = new int[3];
- SimpleBlas.sysv('U', A, p, x);
- A.print();
- x.print();
- }
-
- public static void testSYEV() {
- System.out.println("--- Symmetric eigenvalues");
- int n = 10;
- DoubleMatrix x = DoubleMatrix.randn(n).sort();
-
- //DoubleMatrix A = new DoubleMatrix(new double[][] {{1.0, 0.5, 0.1}, {0.5, 1.0, 0.5}, {0.1, 0.5, 1.0}});
- DoubleMatrix A = expi(Geometry.pairwiseSquaredDistances(x, x).muli(-2.0));
- DoubleMatrix w = new DoubleMatrix(n);
-
- DoubleMatrix B = A.dup();
- System.out.println("Computing eigenvalues with SYEV");
- SimpleBlas.syev('V', 'U', B, w);
- System.out.println("Eigenvalues: ");
- w.print();
- System.out.println("Eigenvectors: ");
- B.print();
-
- B = A.dup();
- System.out.println("Computing eigenvalues with SYEVD");
- SimpleBlas.syevd('V', 'U', B, w);
- System.out.println("Eigenvalues: ");
- w.print();
- System.out.println("Eigenvectors: ");
- B.print();
- }
+import org.junit.Test;
+
+import static org.jblas.MatrixFunctions.abs;
+import static org.jblas.MatrixFunctions.expi;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+public class TestBlasDouble {
+
+ /**
+ * test sum of absolute values
+ */
+ @Test
+ public void testAsum() {
+ double[] a = new double[]{1.0, 2.0, 3.0, 4.0};
+
+ assertEquals(10.0, NativeBlas.dasum(4, a, 0, 1), 1e-6);
+ assertEquals(4.0, NativeBlas.dasum(2, a, 0, 2), 1e-6);
+ assertEquals(5.0, NativeBlas.dasum(2, a, 1, 1), 1e-6);
+ }
+
+ /**
+ * test scalar product
+ */
+ @Test
+ public void testDot() {
+ double[] a = new double[]{1.0, 2.0, 3.0, 4.0};
+ double[] b = new double[]{4.0, 5.0, 6.0, 7.0};
+
+ assertEquals(32.0, NativeBlas.ddot(3, a, 0, 1, b, 0, 1), 1e-6);
+ assertEquals(22.0, NativeBlas.ddot(2, a, 0, 2, b, 0, 2), 1e-6);
+ assertEquals(5.0 + 12.0 + 21.0, NativeBlas.ddot(3, a, 0, 1, b, 1, 1), 1e-6);
+ }
+
+ @Test
+ public void testSwap() {
+ double[] a = new double[]{1.0, 2.0, 3.0, 4.0};
+ double[] b = new double[]{4.0, 5.0, 6.0, 7.0};
+ double[] c = new double[]{1.0, 2.0, 3.0, 4.0};
+ double[] d = new double[]{4.0, 5.0, 6.0, 7.0};
+
+ NativeBlas.dswap(4, a, 0, 1, b, 0, 1);
+ assertTrue(arraysEqual(a, d));
+ assertTrue(arraysEqual(b, c));
+
+ NativeBlas.dswap(2, a, 0, 2, a, 1, 2);
+ assertTrue(arraysEqual(a, 5.0, 4.0, 7.0, 6.0));
+ }
+
+ /* test vector addition */
+ @Test
+ public void testAxpy() {
+ double[] x = new double[]{1.0, 2.0, 3.0, 4.0};
+ double[] y = new double[]{0.0, 0.0, 0.0, 0.0};
+
+ NativeBlas.daxpy(4, 2.0, x, 0, 1, y, 0, 1);
+
+ for (int i = 0; i < 4; i++)
+ assertEquals(2 * x[i], y[i], 1e-6);
+ }
+
+ /* test matrix-vector multiplication */
+ @Test
+ public void testGemv() {
+ double[] A = new double[]{1.0, 2.0, 3.0,
+ 4.0, 5.0, 6.0,
+ 7.0, 8.0, 9.0};
+
+ double[] x = new double[]{1.0, 3.0, 7.0};
+ double[] y = new double[]{0.0, 0.0, 0.0};
+
+ NativeBlas.dgemv('N', 3, 3, 1.0, A, 0, 3, x, 0, 1, 0.0, y, 0, 1);
+
+ //printMatrix(3, 3, A);
+ //printMatrix(3, 1, x);
+ //printMatrix(3, 1, y);
+
+ assertTrue(arraysEqual(y, 62.0, 73.0, 84.0));
+
+ NativeBlas.dgemv('T', 3, 3, 1.0, A, 0, 3, x, 0, 1, 0.5, y, 0, 1);
+
+ //printMatrix(3, 1, y);
+ assertTrue(arraysEqual(y, 59.0, 97.5, 136.0));
+ }
+
+ /**
+ * Compare double buffer against an array of doubles
+ */
+ private boolean arraysEqual(double[] a, double... b) {
+ if (a.length != b.length)
+ return false;
+ else {
+ double diff = 0.0;
+ for (int i = 0; i < b.length; i++)
+ diff += abs(a[i] - b[i]);
+ return diff < 1e-6 * a.length;
+ }
+ }
+
+ @Test
+ public void testSolve() {
+ // A = [3 1 2; 5 0 4; 6 0 0] (column-major data), b = [1, 2, 3]'
+ DoubleMatrix A = new DoubleMatrix(3, 3, 3.0, 5.0, 6.0, 1.0, 0.0, 0.0, 2.0, 4.0, 0.0);
+ DoubleMatrix X = new DoubleMatrix(3, 1, 1.0, 2.0, 3.0);
+ DoubleMatrix A0 = A.dup();
+ DoubleMatrix b = X.dup();
+ int[] p = new int[3];
+
+ // gesv overwrites X with the solution of A * X = b; the pivot indices in p
+ // only describe the row interchanges of the LU factors stored in A, so X
+ // must not be "de-shuffled" with them.
+ SimpleBlas.gesv(A, p, X);
+
+ // Derived by hand: 6*x0 = 3, 5*x0 + 4*x2 = 2, 3*x0 + x1 + 2*x2 = 1.
+ assertTrue(arraysEqual(X.data, 0.5, -0.25, -0.125));
+
+ // Independent check: the residual A * X - b must vanish.
+ assertTrue(arraysEqual(A0.mmul(X).data, b.data));
+ }
+
+ @Test
+ public void testSymmetricSolve() {
+ //System.out.println("--- Symmetric solve");
+ DoubleMatrix A = new DoubleMatrix(3, 3, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0);
+ DoubleMatrix x = new DoubleMatrix(3, 1, 1.0, 2.0, 3.0);
+ int[] p = new int[3];
+ SimpleBlas.sysv('U', A, p, x);
+ //A.print();
+ //x.print();
+ assertTrue(arraysEqual(x.data, 0.3, 0.0, 0.1));
+ }
+
+ @Test
+ public void testSYEV() {
+ /* From R:
+ > x <- matrix(c(1,0.5,0.1,0.5,1.0, 0.5, 0.1, 0.5, 1.0), 3, 3)
+ > eigen(x)
+ $values
+ [1] 1.7588723 0.9000000 0.3411277
+
+ $vectors
+ [,1] [,2] [,3]
+ [1,] -0.5173332 -7.071068e-01 0.4820439
+ [2,] -0.6817131 -7.182297e-16 -0.7316196
+ [3,] -0.5173332 7.071068e-01 0.4820439
+ */
+
+ DoubleMatrix A = new DoubleMatrix(new double[][]{{1.0, 0.5, 0.1}, {0.5, 1.0, 0.5}, {0.1, 0.5, 1.0}});
+ DoubleMatrix w = new DoubleMatrix(3);
+
+ DoubleMatrix B = A.dup();
+ SimpleBlas.syev('V', 'U', B, w);
+
+ assertTrue(arraysEqual(w.data, 0.34112765606210876, 0.9, 1.7588723439378915));
+ assertTrue(arraysEqual(B.data, -0.48204393949466345, 0.731619628490741, -0.482043939494664, -0.7071067811865474, 1.3877787807814457E-16, 0.707106781186547, 0.5173332005549852, 0.6817130768931094, 0.5173332005549856));
+
+ B = A.dup();
+ SimpleBlas.syevd('V', 'U', B, w);
+
+ assertTrue(arraysEqual(w.data, 0.34112765606210876, 0.9, 1.7588723439378915));
+ assertTrue(arraysEqual(B.data, -0.48204393949466345, 0.731619628490741, -0.482043939494664, -0.7071067811865474, 1.3877787807814457E-16, 0.707106781186547, 0.5173332005549852, 0.6817130768931094, 0.5173332005549856));
+ }
}
diff --git a/src/test/java/org/jblas/TestBlasDoubleComplex.java b/src/test/java/org/jblas/TestBlasDoubleComplex.java
index 1708ee7..6dc3a1e 100644
--- a/src/test/java/org/jblas/TestBlasDoubleComplex.java
+++ b/src/test/java/org/jblas/TestBlasDoubleComplex.java
@@ -35,20 +35,23 @@
// --- END LICENSE BLOCK ---
package org.jblas;
-import junit.framework.TestCase;
+import org.junit.*;
+import static org.junit.Assert.*;
-public class TestBlasDoubleComplex extends TestCase {
+public class TestBlasDoubleComplex {
+ @Test
public void testZCOPY() {
double[] a = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 };
double[] b = new double[6];
NativeBlas.zcopy(3, a, 0, 1, b, 0, 1);
for (int i = 0; i < 6; i++) {
- assertEquals((double)(i+1), b[i]);
+ assertEquals((double)(i+1), b[i], 1e-6);
}
}
+ @Test
public void testZDOTU() {
double[] a = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
@@ -56,6 +59,7 @@ public class TestBlasDoubleComplex extends TestCase {
assertEquals(new ComplexDouble(-21.0, 88.0), c);
}
+ @Test
public void testAxpy() {
double[] x = {0.0, -1.0};
double[] y = {0.0, 1.0};
diff --git a/src/test/java/org/jblas/TestBlasFloat.java b/src/test/java/org/jblas/TestBlasFloat.java
index 89224f7..5c2d4c8 100644
--- a/src/test/java/org/jblas/TestBlasFloat.java
+++ b/src/test/java/org/jblas/TestBlasFloat.java
@@ -36,148 +36,166 @@
package org.jblas;
-import junit.framework.TestCase;
-
-import static org.jblas.MatrixFunctions.*;
-
-public class TestBlasFloat extends TestCase {
-
- /** test sum of absolute values */
- public void testAsum() {
- float[] a = new float[]{1.0f, 2.0f, 3.0f, 4.0f};
-
- assertEquals(10.0f, NativeBlas.sasum(4, a, 0, 1));
- assertEquals(4.0f, NativeBlas.sasum(2, a, 0, 2));
- assertEquals(5.0f, NativeBlas.sasum(2, a, 1, 1));
- }
-
- /** test scalar product */
- public void testDot() {
- float[] a = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
- float[] b = new float[] { 4.0f, 5.0f, 6.0f, 7.0f };
-
- assertEquals(32.0f, NativeBlas.sdot(3, a, 0, 1, b, 0, 1));
- assertEquals(22.0f, NativeBlas.sdot(2, a, 0, 2, b, 0, 2));
- assertEquals(5.0f + 12.0f + 21.0f, NativeBlas.sdot(3, a, 0, 1, b, 1, 1));
- }
-
- public void testSwap() {
- float[] a = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
- float[] b = new float[] { 4.0f, 5.0f, 6.0f, 7.0f };
- float[] c = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
- float[] d = new float[] { 4.0f, 5.0f, 6.0f, 7.0f };
-
- System.out.println("dswap");
- NativeBlas.sswap(4, a, 0, 1, b, 0, 1);
- assertTrue(arraysEqual(a, d));
- assertTrue(arraysEqual(b, c));
-
- System.out.println("dswap same");
- NativeBlas.sswap(2, a, 0, 2, a, 1, 2);
- assertTrue(arraysEqual(a, 5.0f, 4.0f, 7.0f, 6.0f));
- }
-
- /* test vector addition */
- public void testAxpy() {
- float[] x = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
- float[] y = new float[] { 0.0f, 0.0f, 0.0f, 0.0f };
-
- NativeBlas.saxpy(4, 2.0f, x, 0, 1, y, 0, 1);
-
- for(int i = 0; i < 4; i++)
- assertEquals(2*x[i], y[i]);
- }
-
- /* test matric-vector multiplication */
- public void testGemv() {
- float[] A = new float[] { 1.0f, 2.0f, 3.0f,
- 4.0f, 5.0f, 6.0f,
- 7.0f, 8.0f, 9.0f };
-
- float[] x = new float[] {1.0f, 3.0f, 7.0f };
- float[] y = new float[] { 0.0f, 0.0f, 0.0f };
-
- NativeBlas.sgemv('N', 3, 3, 1.0f, A, 0, 3, x, 0, 1, 0.0f, y, 0, 1);
-
- //printMatrix(3, 3, A);
- //printMatrix(3, 1, x);
- //printMatrix(3, 1, y);
-
- assertTrue(arraysEqual(y, 62.0f, 73.0f, 84.0f));
-
- NativeBlas.sgemv('T', 3, 3, 1.0f, A, 0, 3, x, 0, 1, 0.5f, y, 0, 1);
-
- //printMatrix(3, 1, y);
- assertTrue(arraysEqual(y, 59.0f, 97.5f, 136.0f));
- }
-
- /** Compare float buffer against an array of floats */
- private boolean arraysEqual(float[] a, float... b) {
- if (a.length != b.length)
- return false;
- else {
- float diff = 0.0f;
- for (int i = 0; i < b.length; i++)
- diff += abs(a[i] - b[i]);
- return diff < 1e-6;
- }
- }
-
- public static void main(String[] args) {
- TestBlasFloat t = new TestBlasFloat();
-
- t.testAsum();
- }
-
- public static void testSolve() {
- FloatMatrix A = new FloatMatrix(3, 3, 3.0f, 5.0f, 6.0f, 1.0f, 0.0f, 0.0f, 2.0f, 4.0f, 0.0f);
- FloatMatrix X = new FloatMatrix(3, 1, 1.0f, 2.0f, 3.0f);
- int[] p = new int[3];
- SimpleBlas.gesv(A, p, X);
- A.print();
- X.print();
- // De-shuffle X
- for (int i = 2; i >= 0; i--) {
- int perm = p[i] - 1;
- float t = X.get(i); X.put(i, X.get(perm)); X.put(perm, t);
- }
- System.out.println();
- X.print();
- }
-
- public static void testSymmetricSolve() {
- System.out.println("--- Symmetric solve");
- FloatMatrix A = new FloatMatrix(3, 3, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f);
- FloatMatrix x = new FloatMatrix(3, 1, 1.0f, 2.0f, 3.0f);
- int[] p = new int[3];
- SimpleBlas.sysv('U', A, p, x);
- A.print();
- x.print();
- }
-
- public static void testSYEV() {
- System.out.println("--- Symmetric eigenvalues");
- int n = 10;
- FloatMatrix x = FloatMatrix.randn(n).sort();
-
- //FloatMatrix A = new FloatMatrix(new float[][] {{1.0f, 0.5f, 0.1f}, {0.5f, 1.0f, 0.5f}, {0.1f, 0.5f, 1.0f}});
- FloatMatrix A = expi(Geometry.pairwiseSquaredDistances(x, x).muli(-2.0f));
- FloatMatrix w = new FloatMatrix(n);
-
- FloatMatrix B = A.dup();
- System.out.println("Computing eigenvalues with SYEV");
- SimpleBlas.syev('V', 'U', B, w);
- System.out.println("Eigenvalues: ");
- w.print();
- System.out.println("Eigenvectors: ");
- B.print();
-
- B = A.dup();
- System.out.println("Computing eigenvalues with SYEVD");
- SimpleBlas.syevd('V', 'U', B, w);
- System.out.println("Eigenvalues: ");
- w.print();
- System.out.println("Eigenvectors: ");
- B.print();
- }
-}
+import org.junit.Test;
+
+import static org.jblas.MatrixFunctions.abs;
+import static org.jblas.MatrixFunctions.expi;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+public class TestBlasFloat {
+
+ /**
+ * test sum of absolute values
+ */
+ @Test
+ public void testAsum() {
+ float[] a = new float[]{1.0f, 2.0f, 3.0f, 4.0f};
+
+ assertEquals(10.0f, NativeBlas.sasum(4, a, 0, 1), 1e-6);
+ assertEquals(4.0f, NativeBlas.sasum(2, a, 0, 2), 1e-6);
+ assertEquals(5.0f, NativeBlas.sasum(2, a, 1, 1), 1e-6);
+ }
+
+ /**
+ * test scalar product
+ */
+ @Test
+ public void testDot() {
+ float[] a = new float[]{1.0f, 2.0f, 3.0f, 4.0f};
+ float[] b = new float[]{4.0f, 5.0f, 6.0f, 7.0f};
+
+ assertEquals(32.0f, NativeBlas.sdot(3, a, 0, 1, b, 0, 1), 1e-6);
+ assertEquals(22.0f, NativeBlas.sdot(2, a, 0, 2, b, 0, 2), 1e-6);
+ assertEquals(5.0f + 12.0f + 21.0f, NativeBlas.sdot(3, a, 0, 1, b, 1, 1), 1e-6);
+ }
+
+ @Test
+ public void testSwap() {
+ float[] a = new float[]{1.0f, 2.0f, 3.0f, 4.0f};
+ float[] b = new float[]{4.0f, 5.0f, 6.0f, 7.0f};
+ float[] c = new float[]{1.0f, 2.0f, 3.0f, 4.0f};
+ float[] d = new float[]{4.0f, 5.0f, 6.0f, 7.0f};
+
+ NativeBlas.sswap(4, a, 0, 1, b, 0, 1);
+ assertTrue(arraysEqual(a, d));
+ assertTrue(arraysEqual(b, c));
+
+ NativeBlas.sswap(2, a, 0, 2, a, 1, 2);
+ assertTrue(arraysEqual(a, 5.0f, 4.0f, 7.0f, 6.0f));
+ }
+
+ /* test vector addition */
+ @Test
+ public void testAxpy() {
+ float[] x = new float[]{1.0f, 2.0f, 3.0f, 4.0f};
+ float[] y = new float[]{0.0f, 0.0f, 0.0f, 0.0f};
+
+ NativeBlas.saxpy(4, 2.0f, x, 0, 1, y, 0, 1);
+
+ for (int i = 0; i < 4; i++)
+ assertEquals(2 * x[i], y[i], 1e-6);
+ }
+
+ /* test matrix-vector multiplication */
+ @Test
+ public void testGemv() {
+ float[] A = new float[]{1.0f, 2.0f, 3.0f,
+ 4.0f, 5.0f, 6.0f,
+ 7.0f, 8.0f, 9.0f};
+
+ float[] x = new float[]{1.0f, 3.0f, 7.0f};
+ float[] y = new float[]{0.0f, 0.0f, 0.0f};
+
+ NativeBlas.sgemv('N', 3, 3, 1.0f, A, 0, 3, x, 0, 1, 0.0f, y, 0, 1);
+
+ //printMatrix(3, 3, A);
+ //printMatrix(3, 1, x);
+ //printMatrix(3, 1, y);
+
+ assertTrue(arraysEqual(y, 62.0f, 73.0f, 84.0f));
+
+ NativeBlas.sgemv('T', 3, 3, 1.0f, A, 0, 3, x, 0, 1, 0.5f, y, 0, 1);
+
+ //printMatrix(3, 1, y);
+ assertTrue(arraysEqual(y, 59.0f, 97.5f, 136.0f));
+ }
+
+ /**
+ * Compare float buffer against an array of floats
+ */
+ private boolean arraysEqual(float[] a, float... b) {
+ if (a.length != b.length)
+ return false;
+ else {
+ float diff = 0.0f;
+ for (int i = 0; i < b.length; i++)
+ diff += abs(a[i] - b[i]);
+ return diff < 1e-6 * a.length;
+ }
+ }
+
+ @Test
+ public void testSolve() {
+ // A = [3 1 2; 5 0 4; 6 0 0] (column-major data), b = [1, 2, 3]'
+ FloatMatrix A = new FloatMatrix(3, 3, 3.0f, 5.0f, 6.0f, 1.0f, 0.0f, 0.0f, 2.0f, 4.0f, 0.0f);
+ FloatMatrix X = new FloatMatrix(3, 1, 1.0f, 2.0f, 3.0f);
+ FloatMatrix A0 = A.dup();
+ FloatMatrix b = X.dup();
+ int[] p = new int[3];
+
+ // gesv overwrites X with the solution of A * X = b; the pivot indices in p
+ // only describe the row interchanges of the LU factors stored in A, so X
+ // must not be "de-shuffled" with them.
+ SimpleBlas.gesv(A, p, X);
+
+ // Derived by hand: 6*x0 = 3, 5*x0 + 4*x2 = 2, 3*x0 + x1 + 2*x2 = 1.
+ assertTrue(arraysEqual(X.data, 0.5f, -0.25f, -0.125f));
+
+ // Independent check: the residual A * X - b must vanish.
+ assertTrue(arraysEqual(A0.mmul(X).data, b.data));
+ }
+
+ @Test
+ public void testSymmetricSolve() {
+ //System.out.println("--- Symmetric solve");
+ FloatMatrix A = new FloatMatrix(3, 3, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f);
+ FloatMatrix x = new FloatMatrix(3, 1, 1.0f, 2.0f, 3.0f);
+ int[] p = new int[3];
+ SimpleBlas.sysv('U', A, p, x);
+ //A.print();
+ //x.print();
+ assertTrue(arraysEqual(x.data, 0.3f, 0.0f, 0.1f));
+ }
+
+ @Test
+ public void testSYEV() {
+ /* From R:
+ > x <- matrix(c(1,0.5,0.1,0.5,1.0, 0.5, 0.1, 0.5, 1.0), 3, 3)
+ > eigen(x)
+ $values
+ [1] 1.7588723 0.9000000 0.3411277
+
+ $vectors
+ [,1] [,2] [,3]
+ [1,] -0.5173332 -7.071068e-01 0.4820439
+ [2,] -0.6817131 -7.182297e-16 -0.7316196
+ [3,] -0.5173332 7.071068e-01 0.4820439
+ */
+
+ FloatMatrix A = new FloatMatrix(new float[][]{{1.0f, 0.5f, 0.1f}, {0.5f, 1.0f, 0.5f}, {0.1f, 0.5f, 1.0f}});
+ FloatMatrix w = new FloatMatrix(3);
+
+ FloatMatrix B = A.dup();
+ SimpleBlas.syev('V', 'U', B, w);
+
+ assertTrue(arraysEqual(w.data, 0.34112765606210876f, 0.9f, 1.7588723439378915f));
+ assertTrue(arraysEqual(B.data, -0.48204393949466345f, 0.731619628490741f, -0.482043939494664f, -0.7071067811865474f, 1.3877787807814457E-16f, 0.707106781186547f, 0.5173332005549852f, 0.6817130768931094f, 0.5173332005549856f));
+
+ B = A.dup();
+ SimpleBlas.syevd('V', 'U', B, w);
+
+ assertTrue(arraysEqual(w.data, 0.34112765606210876f, 0.9f, 1.7588723439378915f));
+ assertTrue(arraysEqual(B.data, -0.48204393949466345f, 0.731619628490741f, -0.482043939494664f, -0.7071067811865474f, 1.3877787807814457E-16f, 0.707106781186547f, 0.5173332005549852f, 0.6817130768931094f, 0.5173332005549856f));
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/org/jblas/TestComplexFloat.java b/src/test/java/org/jblas/TestComplexFloat.java
index bbbc5f5..1314f9e 100644
--- a/src/test/java/org/jblas/TestComplexFloat.java
+++ b/src/test/java/org/jblas/TestComplexFloat.java
@@ -36,33 +36,40 @@
package org.jblas;
-import junit.framework.TestCase;
+import org.junit.Before;
+import org.junit.Test;
-public class TestComplexFloat extends TestCase {
- public TestComplexFloat() {
- }
+import static org.junit.Assert.assertEquals;
+
+public class TestComplexFloat {
private ComplexFloat a, b;
-
+
+ private final double eps = 1e-16;
+
+ @Before
public void setUp() {
a = new ComplexFloat(1, 2);
b = new ComplexFloat(3, 4);
}
-
+
+ @Test
public void testAdd() {
ComplexFloat c = a.add(b);
- assertEquals(4.0f, c.real());
- assertEquals(6.0f, c.imag());
+ assertEquals(4.0f, c.real(), eps);
+ assertEquals(6.0f, c.imag(), eps);
}
+ @Test
public void testMul() {
ComplexFloat c = a.mul(b);
- assertEquals(-5.0f, c.real());
- assertEquals(10.0f, c.imag());
+ assertEquals(-5.0f, c.real(), eps);
+ assertEquals(10.0f, c.imag(), eps);
}
-
+
+ @Test
public void testMulAndDiv() {
ComplexFloat d = a.mul(b).div(b);
@@ -72,7 +79,8 @@ public class TestComplexFloat extends TestCase {
assertEquals(new ComplexFloat(1.0f, 2.0f), d);
}
-
+
+ @Test
public void testDivByZero() {
a.div(new ComplexFloat(0.0f, 0.0f));
}
diff --git a/src/test/java/org/jblas/TestDecompose.java b/src/test/java/org/jblas/TestDecompose.java
new file mode 100644
index 0000000..d8ab3d9
--- /dev/null
+++ b/src/test/java/org/jblas/TestDecompose.java
@@ -0,0 +1,94 @@
+package org.jblas;
+
+import org.junit.*;
+import static org.junit.Assert.*;
+
+/**
+ * Test class for Decompose
+ *
+ * @author Mikio Braun
+ */
+public class TestDecompose {
+ @Test
+ public void luDouble() {
+ DoubleMatrix A = new DoubleMatrix(3, 3, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0);
+
+ Decompose.LUDecomposition<DoubleMatrix> lu = Decompose.lu(A);
+
+ assertEquals(0.0, (lu.p.mmul(lu.l).mmul(lu.u).sub(A).normmax()), 1e-10);
+
+ assertTrue(lu.l.isLowerTriangular());
+ assertTrue(lu.u.isUpperTriangular());
+ }
+
+ @Test
+ public void luFloat() {
+ FloatMatrix A = new FloatMatrix(3, 3, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f);
+
+ Decompose.LUDecomposition<FloatMatrix> lu = Decompose.lu(A);
+
+ assertEquals(0.0f, (lu.p.mmul(lu.l).mmul(lu.u).sub(A).normmax()), 1e-6f);
+
+ assertTrue(lu.l.isLowerTriangular());
+ assertTrue(lu.u.isUpperTriangular());
+ }
+
+ @Test
+ public void qrDouble() {
+ DoubleMatrix A = new DoubleMatrix(3, 3, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0);
+
+ Decompose.QRDecomposition<DoubleMatrix> qr = Decompose.qr(A);
+
+ assertEquals(0.0, DoubleMatrix.eye(3).sub(qr.q.transpose().mmul(qr.q)).normmax(), 1e-10);
+ assertTrue(qr.r.isUpperTriangular());
+ assertEquals(0.0, A.sub(qr.q.mmul(qr.r)).normmax(), 1e-10);
+ }
+
+ @Test
+ public void qrRectangularDouble() {
+ DoubleMatrix A = new DoubleMatrix(2, 3, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
+
+ Decompose.QRDecomposition<DoubleMatrix> qr = Decompose.qr(A);
+
+ assertEquals(0.0, DoubleMatrix.eye(2).sub(qr.q.transpose().mmul(qr.q)).normmax(), 1e-10);
+ assertTrue(qr.r.isUpperTriangular());
+ assertEquals(0.0, A.sub(qr.q.mmul(qr.r)).normmax(), 1e-10);
+ }
+
+ @Test
+ public void qrRectangular2Double() {
+ DoubleMatrix A = new DoubleMatrix(4, 2, 1.0, 2.0, 2.5, 3.0, 4.0, 5.0, 6.0, 6.5);
+
+ Decompose.QRDecomposition<DoubleMatrix> qr = Decompose.qr(A);
+
+ DoubleMatrix qtq = qr.q.transpose().mmul(qr.q);
+
+ assertEquals(0.0, DoubleMatrix.eye(4).sub(qtq).normmax(), 1e-10);
+ assertTrue(qr.r.isUpperTriangular());
+ assertEquals(0.0, A.sub(qr.q.mmul(qr.r)).normmax(), 1e-10);
+ }
+
+
+ @Test
+ public void qrFloat() {
+ FloatMatrix A = new FloatMatrix(3, 3, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f);
+
+ Decompose.QRDecomposition<FloatMatrix> qr = Decompose.qr(A);
+
+ assertEquals(0.0f, FloatMatrix.eye(3).sub(qr.q.transpose().mmul(qr.q)).normmax(), 1e-5f);
+ assertTrue(qr.r.isUpperTriangular());
+ assertEquals(0.0f, A.sub(qr.q.mmul(qr.r)).normmax(), 1e-5f);
+ }
+
+ @Test
+ public void qrRectangularFloat() {
+ //FloatMatrix A = new FloatMatrix(2, 3, 1.0f, 2.0f, 3.0f, 4.0f, 7.0f, 6.0f);
+ FloatMatrix A = FloatMatrix.rand(2, 3);
+
+ Decompose.QRDecomposition<FloatMatrix> qr = Decompose.qr(A);
+
+ assertEquals(0.0f, FloatMatrix.eye(2).sub(qr.q.transpose().mmul(qr.q)).normmax(), 1e-5f);
+ assertTrue(qr.r.isUpperTriangular());
+ assertEquals(0.0f, A.sub(qr.q.mmul(qr.r)).normmax(), 1e-5f);
+ }
+}
diff --git a/src/test/java/org/jblas/TestDoubleMatrix.java b/src/test/java/org/jblas/TestDoubleMatrix.java
index b3f5a49..ce97689 100644
--- a/src/test/java/org/jblas/TestDoubleMatrix.java
+++ b/src/test/java/org/jblas/TestDoubleMatrix.java
@@ -38,610 +38,691 @@ package org.jblas;
import java.io.File;
import java.io.PrintStream;
-import junit.framework.TestCase;
+
+import org.jblas.util.Random;
+
import java.util.Arrays;
+
import static org.jblas.ranges.RangeUtils.*;
-public class TestDoubleMatrix extends TestCase {
+import org.junit.Before;
+import org.junit.Test;
- DoubleMatrix A, B, C, D, E, F;
+import static org.junit.Assert.*;
- public void setUp() {
- A = new DoubleMatrix(4, 3, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0);
- B = new DoubleMatrix(3, 1, 2.0, 4.0, 8.0);
- C = new DoubleMatrix(3, 1, -1.0, 2.0, -3.0);
- D = new DoubleMatrix(3, 3, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0);
- E = new DoubleMatrix(3, 3, 1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0);
- F = new DoubleMatrix(3, 1, 3.0, 4.0, 7.0);
- }
+public class TestDoubleMatrix {
- public void testConstructionAndSetGet() {
- double[][] dataA = {{1.0, 5.0, 9.0}, {2.0, 6.0, 10.0}, {3.0, 7.0, 11.0}, {4.0, 8.0, 12.0}};
+ //FLOAT// private final float eps = 1e-6f;
+ private final double eps = 1e-16;
- assertEquals(A.rows, 4);
- assertEquals(A.columns, 3);
+ DoubleMatrix A, B, C, D, E, F;
- for (int r = 0; r < 4; r++) {
- for (int c = 0; c < 3; c++) {
- assertEquals(dataA[r][c], A.get(r, c));
- }
- }
- }
-
- public void testSetAndGet() {
- DoubleMatrix M = new DoubleMatrix(3, 3);
-
- for (int i = 0; i < 3; i++) {
- for (int j = 0; j < 3; j++) {
- M.put(i, j, i + j);
- }
- }
- for (int i = 0; i < 3; i++) {
- for (int j = 0; j < 3; j++) {
- assertEquals((double) i + j, M.get(i, j));
- }
- }
- }
+ @Before
+ public void setUp() {
+ A = new DoubleMatrix(4, 3, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0);
+ B = new DoubleMatrix(3, 1, 2.0, 4.0, 8.0);
+ C = new DoubleMatrix(3, 1, -1.0, 2.0, -3.0);
+ D = new DoubleMatrix(3, 3, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0);
+ E = new DoubleMatrix(3, 3, 1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0);
+ F = new DoubleMatrix(3, 1, 3.0, 4.0, 7.0);
+ }
- public void testCopy() {
- DoubleMatrix M = new DoubleMatrix();
+ @Test
+ public void testConstructionAndSetGet() {
+ double[][] dataA = {{1.0, 5.0, 9.0}, {2.0, 6.0, 10.0}, {3.0, 7.0, 11.0}, {4.0, 8.0, 12.0}};
- assertFalse(M.equals(A));
+ assertEquals(A.rows, 4);
+ assertEquals(A.columns, 3);
- M.copy(A);
- assertEquals(M, A);
+ for (int r = 0; r < 4; r++) {
+ for (int c = 0; c < 3; c++) {
+ assertEquals(dataA[r][c], A.get(r, c), eps);
+ }
}
+ }
- public void testDup() {
- DoubleMatrix M = A.dup();
- assertEquals(M, A);
+ @Test
+ public void testSetAndGet() {
+ DoubleMatrix M = new DoubleMatrix(3, 3);
- M.put(0, 0, 2.0);
- assertFalse(M.equals(A));
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 3; j++) {
+ M.put(i, j, i + j);
+ }
}
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 3; j++) {
+ assertEquals((double) i + j, M.get(i, j), eps);
+ }
+ }
+ }
- public void testResize() {
- DoubleMatrix M = A.dup();
+ @Test
+ public void testGetWithRowIndicesAndSingleColumn() {
+ DoubleMatrix M = new DoubleMatrix(new double[][] {{1, 2}, {3, 4}});
- assertEquals(4, M.rows);
- assertEquals(3, M.columns);
+ assertEquals(new DoubleMatrix(2, 1, 1, 3), M.get(new int[]{0, 1}, 0));
+ assertEquals(new DoubleMatrix(2, 1, 2, 4), M.get(new int[]{0, 1}, 1));
+ }
- M.resize(4, 5);
- assertEquals(4, M.rows);
- assertEquals(5, M.columns);
+ @Test
+ public void testCopy() {
+ DoubleMatrix M = new DoubleMatrix();
- assertEquals(0.0, M.get(3, 4));
- }
+ assertFalse(M.equals(A));
- public void testReshape() {
- DoubleMatrix M = new DoubleMatrix(2, 2, 1.0, 2.0, 3.0, 4.0);
+ M.copy(A);
+ assertEquals(M, A);
+ }
- M.reshape(1, 4);
- assertEquals(1.0, M.get(0, 0));
- assertEquals(4.0, M.get(0, 3));
+ @Test
+ public void testDup() {
+ DoubleMatrix M = A.dup();
+ assertEquals(M, A);
- M.reshape(4, 1);
- assertEquals(1.0, M.get(0, 0));
- assertEquals(4.0, M.get(3, 0));
- }
+ M.put(0, 0, 2.0);
+ assertFalse(M.equals(A));
+ }
- public void testMmul() {
- DoubleMatrix R = A.dup();
- DoubleMatrix result = new DoubleMatrix(4, 1, 94.0, 108.0, 122.0, 136.0);
+ @Test
+ public void testResize() {
+ DoubleMatrix M = A.dup();
- A.mmuli(B, R);
- assertEquals(result, R);
+ assertEquals(4, M.rows);
+ assertEquals(3, M.columns);
- assertEquals(result, A.mmul(B));
+ M.resize(4, 5);
+ assertEquals(4, M.rows);
+ assertEquals(5, M.columns);
- DoubleMatrix resultDE = new DoubleMatrix(3, 3, 14.0, 16.0, 18.0, -26.0, -31.0, -36.0, 38.0, 46.0, 54.0);
+ assertEquals(0.0, M.get(3, 4), eps);
+ }
- // In-place with independent operands
- assertEquals(resultDE, D.mmuli(E, R));
+ @Test
+ public void testReshape() {
+ DoubleMatrix M = new DoubleMatrix(2, 2, 1.0, 2.0, 3.0, 4.0);
- // In-place on this
- R = D.dup();
- assertEquals(resultDE, R.mmuli(E, R));
+ M.reshape(1, 4);
+ assertEquals(1.0, M.get(0, 0), eps);
+ assertEquals(4.0, M.get(0, 3), eps);
- // In-place on this
- R = E.dup();
- assertEquals(resultDE, D.mmuli(R, R));
+ M.reshape(4, 1);
+ assertEquals(1.0, M.get(0, 0), eps);
+ assertEquals(4.0, M.get(3, 0), eps);
+ }
- // Fully dynamic
- assertEquals(resultDE, D.mmul(E));
- }
+ @Test
+ public void testMmul() {
+ DoubleMatrix R = A.dup();
+ DoubleMatrix result = new DoubleMatrix(4, 1, 94.0, 108.0, 122.0, 136.0);
- public void testAdd() {
- DoubleMatrix result = new DoubleMatrix(3, 1, 1.0, 6.0, 5.0);
+ A.mmuli(B, R);
+ assertEquals(result, R);
- DoubleMatrix R = new DoubleMatrix();
+ assertEquals(result, A.mmul(B));
- // In-place, but independent operands
- B.addi(C, R);
- assertEquals(result, R);
+ DoubleMatrix resultDE = new DoubleMatrix(3, 3, 14.0, 16.0, 18.0, -26.0, -31.0, -36.0, 38.0, 46.0, 54.0);
- // In-place on this
- R = B.dup();
- assertEquals(result, R.addi(C, R));
+ // In-place with independent operands
+ assertEquals(resultDE, D.mmuli(E, R));
- // In-place on other
- R = C.dup();
- assertEquals(result, B.addi(R, R));
+ // In-place on this
+ R = D.dup();
+ assertEquals(resultDE, R.mmuli(E, R));
- // fully dynamic
- assertEquals(result, B.add(C));
+ // In-place on this
+ R = E.dup();
+ assertEquals(resultDE, D.mmuli(R, R));
- result = new DoubleMatrix(3, 1, 3.0, 5.0, 9.0);
+ // Fully dynamic
+ assertEquals(resultDE, D.mmul(E));
+ }
- // In-place, but independent operands
- assertEquals(result, B.addi(1.0, R));
+ @Test
+ public void testAdd() {
+ DoubleMatrix result = new DoubleMatrix(3, 1, 1.0, 6.0, 5.0);
- // In-place on this
- R = B.dup();
- assertEquals(result, R.addi(1.0, R));
+ DoubleMatrix R = new DoubleMatrix();
- // fully dynamic
- assertEquals(result, B.add(1.0));
- }
+ // In-place, but independent operands
+ B.addi(C, R);
+ assertEquals(result, R);
- public void testSub() {
- DoubleMatrix result = new DoubleMatrix(3, 1, 3.0, 2.0, 11.0);
+ // In-place on this
+ R = B.dup();
+ assertEquals(result, R.addi(C, R));
- DoubleMatrix R = new DoubleMatrix();
+ // In-place on other
+ R = C.dup();
+ assertEquals(result, B.addi(R, R));
- // In-place, but independent operands
- assertEquals(result, B.subi(C, R));
+ // fully dynamic
+ assertEquals(result, B.add(C));
- // In-place on this
- R = B.dup();
- assertEquals(result, R.subi(C, R));
+ result = new DoubleMatrix(3, 1, 3.0, 5.0, 9.0);
- // In-place on other
- R = C.dup();
- assertEquals(result, B.subi(R, R));
+ // In-place, but independent operands
+ assertEquals(result, B.addi(1.0, R));
- // fully dynamic
- assertEquals(result, B.sub(C));
+ // In-place on this
+ R = B.dup();
+ assertEquals(result, R.addi(1.0, R));
- result = new DoubleMatrix(3, 1, 1.0, 3.0, 7.0);
+ // fully dynamic
+ assertEquals(result, B.add(1.0));
+ }
- // In-place, but independent operands
- assertEquals(result, B.subi(1.0, R));
+ @Test
+ public void testSub() {
+ DoubleMatrix result = new DoubleMatrix(3, 1, 3.0, 2.0, 11.0);
- // In-place on this
- R = B.dup();
- assertEquals(result, R.subi(1.0, R));
+ DoubleMatrix R = new DoubleMatrix();
- // fully dynamic
- assertEquals(result, B.sub(1.0));
- }
+ // In-place, but independent operands
+ assertEquals(result, B.subi(C, R));
- public void testRsub() {
- DoubleMatrix result = new DoubleMatrix(3, 1, 3.0, 2.0, 11.0);
+ // In-place on this
+ R = B.dup();
+ assertEquals(result, R.subi(C, R));
- DoubleMatrix R = new DoubleMatrix();
+ // In-place on other
+ R = C.dup();
+ assertEquals(result, B.subi(R, R));
- // In-place, but independent operands
- assertEquals(result, C.rsubi(B, R));
+ // fully dynamic
+ assertEquals(result, B.sub(C));
- // In-place on this
- R = C.dup();
- assertEquals(result, R.rsubi(B, R));
+ result = new DoubleMatrix(3, 1, 1.0, 3.0, 7.0);
- // In-place on other
- R = B.dup();
- assertEquals(result, C.rsubi(R, R));
+ // In-place, but independent operands
+ assertEquals(result, B.subi(1.0, R));
- // fully dynamic
- assertEquals(result, C.rsub(B));
+ // In-place on this
+ R = B.dup();
+ assertEquals(result, R.subi(1.0, R));
- result = new DoubleMatrix(3, 1, -1.0, -3.0, -7.0);
+ // fully dynamic
+ assertEquals(result, B.sub(1.0));
+ }
- // In-place, but independent operands
- assertEquals(result, B.rsubi(1.0, R));
+ @Test
+ public void testRsub() {
+ DoubleMatrix result = new DoubleMatrix(3, 1, 3.0, 2.0, 11.0);
- // In-place on this
- R = B.dup();
- assertEquals(result, R.rsubi(1.0, R));
+ DoubleMatrix R = new DoubleMatrix();
- // fully dynamic
- assertEquals(result, B.rsub(1.0));
- }
+ // In-place, but independent operands
+ assertEquals(result, C.rsubi(B, R));
- public void testMul() {
- DoubleMatrix result = new DoubleMatrix(3, 1, -2.0, 8.0, -24.0);
+ // In-place on this
+ R = C.dup();
+ assertEquals(result, R.rsubi(B, R));
- DoubleMatrix R = new DoubleMatrix();
+ // In-place on other
+ R = B.dup();
+ assertEquals(result, C.rsubi(R, R));
- // In-place, but independent operands
- assertEquals(result, B.muli(C, R));
+ // fully dynamic
+ assertEquals(result, C.rsub(B));
- // In-place on this
- R = B.dup();
- assertEquals(result, R.muli(C, R));
+ result = new DoubleMatrix(3, 1, -1.0, -3.0, -7.0);
- // In-place on other
- R = C.dup();
- assertEquals(result, B.muli(R, R));
+ // In-place, but independent operands
+ assertEquals(result, B.rsubi(1.0, R));
- // fully dynamic
- assertEquals(result, B.mul(C));
+ // In-place on this
+ R = B.dup();
+ assertEquals(result, R.rsubi(1.0, R));
- result = new DoubleMatrix(3, 1, 1.0, 2.0, 4.0);
+ // fully dynamic
+ assertEquals(result, B.rsub(1.0));
+ }
- // In-place, but independent operands
- assertEquals(result, B.muli(0.5, R));
+ @Test
+ public void testMul() {
+ DoubleMatrix result = new DoubleMatrix(3, 1, -2.0, 8.0, -24.0);
- // In-place on this
- R = B.dup();
- assertEquals(result, R.muli(0.5, R));
+ DoubleMatrix R = new DoubleMatrix();
- // fully dynamic
- assertEquals(result, B.mul(0.5));
- }
+ // In-place, but independent operands
+ assertEquals(result, B.muli(C, R));
- public void testDiv() {
- DoubleMatrix result = new DoubleMatrix(3, 1, -2.0, 2.0, -2.666666666);
+ // In-place on this
+ R = B.dup();
+ assertEquals(result, R.muli(C, R));
- DoubleMatrix R = new DoubleMatrix();
+ // In-place on other
+ R = C.dup();
+ assertEquals(result, B.muli(R, R));
- // In-place, but independent operands
- assertEquals(result, B.divi(C, R));
+ // fully dynamic
+ assertEquals(result, B.mul(C));
- // In-place on this
- R = B.dup();
- assertEquals(result, R.divi(C, R));
+ result = new DoubleMatrix(3, 1, 1.0, 2.0, 4.0);
- // In-place on other
- R = C.dup();
- assertEquals(result, B.divi(R, R));
+ // In-place, but independent operands
+ assertEquals(result, B.muli(0.5, R));
- // fully dynamic
- assertEquals(result, B.div(C));
+ // In-place on this
+ R = B.dup();
+ assertEquals(result, R.muli(0.5, R));
- result = new DoubleMatrix(3, 1, 1.0, 2.0, 4.0);
+ // fully dynamic
+ assertEquals(result, B.mul(0.5));
+ }
- // In-place, but independent operands
- assertEquals(result, B.divi(2.0, R));
+ @Test
+ public void testDiv() {
+ DoubleMatrix result = new DoubleMatrix(3, 1, -2.0, 2.0, -2.666666666);
- // In-place on this
- R = B.dup();
- assertEquals(result, R.divi(2.0, R));
+ DoubleMatrix R = new DoubleMatrix();
- // fully dynamic
- assertEquals(result, B.div(2.0));
- }
+ // In-place, but independent operands
+ assertEquals(result, B.divi(C, R));
- public void testRdiv() {
- DoubleMatrix result = new DoubleMatrix(3, 1, -2.0, 2.0, -2.666666666);
+ // In-place on this
+ R = B.dup();
+ assertEquals(result, R.divi(C, R));
- DoubleMatrix R = new DoubleMatrix();
+ // In-place on other
+ R = C.dup();
+ assertEquals(result, B.divi(R, R));
- // In-place, but independent operands
- assertEquals(result, C.rdivi(B, R));
+ // fully dynamic
+ assertEquals(result, B.div(C));
- // In-place on this
- R = C.dup();
- assertEquals(result, R.rdivi(B, R));
+ result = new DoubleMatrix(3, 1, 1.0, 2.0, 4.0);
- // In-place on other
- R = B.dup();
- assertEquals(result, C.rdivi(R, R));
+ // In-place, but independent operands
+ assertEquals(result, B.divi(2.0, R));
- // fully dynamic
- assertEquals(result, C.rdiv(B));
+ // In-place on this
+ R = B.dup();
+ assertEquals(result, R.divi(2.0, R));
- result = new DoubleMatrix(3, 1, 0.5, 0.25, 0.125);
+ // fully dynamic
+ assertEquals(result, B.div(2.0));
+ }
- // In-place, but independent operands
- assertEquals(result, B.rdivi(1.0, R));
+ @Test
+ public void testRdiv() {
+ DoubleMatrix result = new DoubleMatrix(3, 1, -2.0, 2.0, -2.666666666);
- // In-place on this
- R = B.dup();
- assertEquals(result, R.rdivi(1.0, R));
+ DoubleMatrix R = new DoubleMatrix();
- // fully dynamic
- assertEquals(result, B.rdiv(1.0));
- }
+ // In-place, but independent operands
+ assertEquals(result, C.rdivi(B, R));
- /*# def test_logical(op, result, scalar, result2); <<-EOS
- public void test#{op.upcase}() {
- DoubleMatrix result = new DoubleMatrix(3, 1, #{result});
- DoubleMatrix result2 = new DoubleMatrix(3, 1, #{result2});
- DoubleMatrix R = new DoubleMatrix();
-
- // in-place but independent operands
- assertEquals(result, B.#{op}i(F, R));
-
- // in-place but in other
- R = F.dup();
- assertEquals(result, B.#{op}i(R, R));
-
- // in-place in this
- R = B.dup();
- assertEquals(result, R.#{op}i(F, R));
-
- // fully dynamic
- assertEquals(result, B.#{op}(F));
-
- // in-place in this
- R = B.dup();
- assertEquals(result2, R.#{op}i(#{scalar}, R));
-
- // fully dynamic
- assertEquals(result2, B.#{op}(#{scalar}));
- }
- EOS
- end
- #*/
- /*# test_logical('lt', '1.0, 0.0, 0.0', 4.0, '1.0, 0.0, 0.0') #*/
-//RJPP-BEGIN------------------------------------------------------------
- public void testLT() {
- DoubleMatrix result = new DoubleMatrix(3, 1, 1.0, 0.0, 0.0);
- DoubleMatrix result2 = new DoubleMatrix(3, 1, 1.0, 0.0, 0.0);
- DoubleMatrix R = new DoubleMatrix();
-
- // in-place but independent operands
- assertEquals(result, B.lti(F, R));
-
- // in-place but in other
- R = F.dup();
- assertEquals(result, B.lti(R, R));
-
- // in-place in this
- R = B.dup();
- assertEquals(result, R.lti(F, R));
-
- // fully dynamic
- assertEquals(result, B.lt(F));
-
- // in-place in this
- R = B.dup();
- assertEquals(result2, R.lti(4.0, R));
-
- // fully dynamic
- assertEquals(result2, B.lt(4.0));
- }
-//RJPP-END--------------------------------------------------------------
- /*# test_logical('le', '1.0, 1.0, 0.0', 4.0, '1.0, 1.0, 0.0') #*/
-//RJPP-BEGIN------------------------------------------------------------
- public void testLE() {
- DoubleMatrix result = new DoubleMatrix(3, 1, 1.0, 1.0, 0.0);
- DoubleMatrix result2 = new DoubleMatrix(3, 1, 1.0, 1.0, 0.0);
- DoubleMatrix R = new DoubleMatrix();
-
- // in-place but independent operands
- assertEquals(result, B.lei(F, R));
-
- // in-place but in other
- R = F.dup();
- assertEquals(result, B.lei(R, R));
-
- // in-place in this
- R = B.dup();
- assertEquals(result, R.lei(F, R));
-
- // fully dynamic
- assertEquals(result, B.le(F));
-
- // in-place in this
- R = B.dup();
- assertEquals(result2, R.lei(4.0, R));
-
- // fully dynamic
- assertEquals(result2, B.le(4.0));
- }
-//RJPP-END--------------------------------------------------------------
- /*# test_logical('gt', '0.0, 0.0, 1.0', 4.0, '0.0, 0.0, 1.0') #*/
-//RJPP-BEGIN------------------------------------------------------------
- public void testGT() {
- DoubleMatrix result = new DoubleMatrix(3, 1, 0.0, 0.0, 1.0);
- DoubleMatrix result2 = new DoubleMatrix(3, 1, 0.0, 0.0, 1.0);
- DoubleMatrix R = new DoubleMatrix();
-
- // in-place but independent operands
- assertEquals(result, B.gti(F, R));
-
- // in-place but in other
- R = F.dup();
- assertEquals(result, B.gti(R, R));
-
- // in-place in this
+ // In-place on this
+ R = C.dup();
+ assertEquals(result, R.rdivi(B, R));
+
+ // In-place on other
R = B.dup();
- assertEquals(result, R.gti(F, R));
-
+ assertEquals(result, C.rdivi(R, R));
+
// fully dynamic
- assertEquals(result, B.gt(F));
-
- // in-place in this
+ assertEquals(result, C.rdiv(B));
+
+ result = new DoubleMatrix(3, 1, 0.5, 0.25, 0.125);
+
+ // In-place, but independent operands
+ assertEquals(result, B.rdivi(1.0, R));
+
+ // In-place on this
R = B.dup();
- assertEquals(result2, R.gti(4.0, R));
-
+ assertEquals(result, R.rdivi(1.0, R));
+
// fully dynamic
- assertEquals(result2, B.gt(4.0));
- }
-//RJPP-END--------------------------------------------------------------
- /*# test_logical('ge', '0.0, 1.0, 1.0', 4.0, '0.0, 1.0, 1.0') #*/
+ assertEquals(result, B.rdiv(1.0));
+ }
+
+ /*# def test_logical(op, result, scalar, result2); <<-EOS
+ @Test
+ public void test#{op.upcase}() {
+ DoubleMatrix result = new DoubleMatrix(3, 1, #{result});
+ DoubleMatrix result2 = new DoubleMatrix(3, 1, #{result2});
+ DoubleMatrix R = new DoubleMatrix();
+
+ // in-place but independent operands
+ assertEquals(result, B.#{op}i(F, R));
+
+ // in-place but in other
+ R = F.dup();
+ assertEquals(result, B.#{op}i(R, R));
+
+ // in-place in this
+ R = B.dup();
+ assertEquals(result, R.#{op}i(F, R));
+
+ // fully dynamic
+ assertEquals(result, B.#{op}(F));
+
+ // in-place in this
+ R = B.dup();
+ assertEquals(result2, R.#{op}i(#{scalar}, R));
+
+ // fully dynamic
+ assertEquals(result2, B.#{op}(#{scalar}));
+ }
+ EOS
+ end
+ #*/
+ /*# test_logical('lt', '1.0, 0.0, 0.0', 4.0, '1.0, 0.0, 0.0') #*/
//RJPP-BEGIN------------------------------------------------------------
- public void testGE() {
- DoubleMatrix result = new DoubleMatrix(3, 1, 0.0, 1.0, 1.0);
- DoubleMatrix result2 = new DoubleMatrix(3, 1, 0.0, 1.0, 1.0);
- DoubleMatrix R = new DoubleMatrix();
-
- // in-place but independent operands
- assertEquals(result, B.gei(F, R));
-
- // in-place but in other
- R = F.dup();
- assertEquals(result, B.gei(R, R));
-
- // in-place in this
- R = B.dup();
- assertEquals(result, R.gei(F, R));
-
- // fully dynamic
- assertEquals(result, B.ge(F));
-
- // in-place in this
- R = B.dup();
- assertEquals(result2, R.gei(4.0, R));
-
- // fully dynamic
- assertEquals(result2, B.ge(4.0));
- }
-//RJPP-END--------------------------------------------------------------
- public void testMinMax() {
- assertEquals(1.0, A.min());
- assertEquals(12.0, A.max());
- }
+ @Test
+ public void testLT() {
+ DoubleMatrix result = new DoubleMatrix(3, 1, 1.0, 0.0, 0.0);
+ DoubleMatrix result2 = new DoubleMatrix(3, 1, 1.0, 0.0, 0.0);
+ DoubleMatrix R = new DoubleMatrix();
- public void testArgMinMax() {
- assertEquals(0, A.argmin());
- assertEquals(11, A.argmax());
- }
+ // in-place but independent operands
+ assertEquals(result, B.lti(F, R));
- public void testTranspose() {
- DoubleMatrix At = A.transpose();
- assertEquals(1.0, At.get(0, 0));
- assertEquals(2.0, At.get(0, 1));
- assertEquals(5.0, At.get(1, 0));
- }
+ // in-place but in other
+ R = F.dup();
+ assertEquals(result, B.lti(R, R));
- public void testGetRowVector() {
- for (int r = 0; r < A.rows; r++) {
- A.getRow(r);
- }
+ // in-place in this
+ R = B.dup();
+ assertEquals(result, R.lti(F, R));
- for (int c = 0; c < A.columns; c++) {
- A.getColumn(c);
- }
+ // fully dynamic
+ assertEquals(result, B.lt(F));
- A.addiRowVector(new DoubleMatrix(3, 1, 10.0, 100.0, 1000.0));
- A.addiColumnVector(new DoubleMatrix(1, 4, 10.0, 100.0, 1000.0, 10000.0));
- }
-
- public void testPairwiseDistance() {
- DoubleMatrix D = Geometry.pairwiseSquaredDistances(A, A);
+ // in-place in this
+ R = B.dup();
+ assertEquals(result2, R.lti(4.0, R));
- DoubleMatrix X = new DoubleMatrix(1, 3, 1.0, 0.0, -1.0);
+ // fully dynamic
+ assertEquals(result2, B.lt(4.0));
+ }
+//RJPP-END--------------------------------------------------------------
+ /*# test_logical('le', '1.0, 1.0, 0.0', 4.0, '1.0, 1.0, 0.0') #*/
+//RJPP-BEGIN------------------------------------------------------------
+ @Test
+ public void testLE() {
+ DoubleMatrix result = new DoubleMatrix(3, 1, 1.0, 1.0, 0.0);
+ DoubleMatrix result2 = new DoubleMatrix(3, 1, 1.0, 1.0, 0.0);
+ DoubleMatrix R = new DoubleMatrix();
- Geometry.pairwiseSquaredDistances(X, X);
+ // in-place but independent operands
+ assertEquals(result, B.lei(F, R));
- DoubleMatrix A1 = new DoubleMatrix(1, 2, 1.0, 2.0);
- DoubleMatrix A2 = new DoubleMatrix(1, 3, 1.0, 2.0, 3.0);
+ // in-place but in other
+ R = F.dup();
+ assertEquals(result, B.lei(R, R));
- Geometry.pairwiseSquaredDistances(A1, A2);
- }
+ // in-place in this
+ R = B.dup();
+ assertEquals(result, R.lei(F, R));
- public void testSwapColumns() {
- DoubleMatrix AA = A.dup();
+ // fully dynamic
+ assertEquals(result, B.le(F));
- AA.swapColumns(1, 2);
- assertEquals(new DoubleMatrix(4, 3, 1.0, 2.0, 3.0, 4.0, 9.0, 10.0, 11.0, 12.0, 5.0, 6.0, 7.0, 8.0), AA);
- }
+ // in-place in this
+ R = B.dup();
+ assertEquals(result2, R.lei(4.0, R));
- public void testSwapRows() {
- DoubleMatrix AA = A.dup();
+ // fully dynamic
+ assertEquals(result2, B.le(4.0));
+ }
+//RJPP-END--------------------------------------------------------------
+ /*# test_logical('gt', '0.0, 0.0, 1.0', 4.0, '0.0, 0.0, 1.0') #*/
+//RJPP-BEGIN------------------------------------------------------------
+ @Test
+ public void testGT() {
+ DoubleMatrix result = new DoubleMatrix(3, 1, 0.0, 0.0, 1.0);
+ DoubleMatrix result2 = new DoubleMatrix(3, 1, 0.0, 0.0, 1.0);
+ DoubleMatrix R = new DoubleMatrix();
- AA.swapRows(1, 2);
- assertEquals(new DoubleMatrix(4, 3, 1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0, 9.0, 11.0, 10.0, 12.0), AA);
- }
+ // in-place but independent operands
+ assertEquals(result, B.gti(F, R));
- public void testSolve() {
- DoubleMatrix AA = new DoubleMatrix(3, 3, 3.0, 5.0, 6.0, 1.0, 0.0, 0.0, 2.0, 4.0, 0.0);
- DoubleMatrix BB = new DoubleMatrix(3, 1, 1.0, 2.0, 3.0);
+ // in-place but in other
+ R = F.dup();
+ assertEquals(result, B.gti(R, R));
- DoubleMatrix Adup = AA.dup();
- DoubleMatrix Bdup = BB.dup();
+ // in-place in this
+ R = B.dup();
+ assertEquals(result, R.gti(F, R));
- DoubleMatrix X = Solve.solve(AA, BB);
+ // fully dynamic
+ assertEquals(result, B.gt(F));
- assertEquals(Adup, AA);
- assertEquals(Bdup, BB);
- }
+ // in-place in this
+ R = B.dup();
+ assertEquals(result2, R.gti(4.0, R));
- public void testConstructFromArray() {
- double[][] data = {
- {1.0, 2.0, 3.0},
- {4.0, 5.0, 6.0},
- {7.0, 8.0, 9.0}
- };
+ // fully dynamic
+ assertEquals(result2, B.gt(4.0));
+ }
+//RJPP-END--------------------------------------------------------------
+ /*# test_logical('ge', '0.0, 1.0, 1.0', 4.0, '0.0, 1.0, 1.0') #*/
+//RJPP-BEGIN------------------------------------------------------------
+ @Test
+ public void testGE() {
+ DoubleMatrix result = new DoubleMatrix(3, 1, 0.0, 1.0, 1.0);
+ DoubleMatrix result2 = new DoubleMatrix(3, 1, 0.0, 1.0, 1.0);
+ DoubleMatrix R = new DoubleMatrix();
- DoubleMatrix A = new DoubleMatrix(data);
+ // in-place but independent operands
+ assertEquals(result, B.gei(F, R));
- for (int r = 0; r < 3; r++) {
- for (int c = 0; c < 3; c++) {
- assertEquals(data[r][c], A.get(r, c));
- }
- }
- }
+ // in-place but in other
+ R = F.dup();
+ assertEquals(result, B.gei(R, R));
- public void testDiag() {
- DoubleMatrix A = new DoubleMatrix(new double[][]{
- {1.0, 2.0, 3.0},
- {4.0, 5.0, 6.0},
- {7.0, 8.0, 9.0}
- });
+ // in-place in this
+ R = B.dup();
+ assertEquals(result, R.gei(F, R));
- assertEquals(new DoubleMatrix(3, 1, 1.0, 5.0, 9.0), A.diag());
+ // fully dynamic
+ assertEquals(result, B.ge(F));
- assertEquals(new DoubleMatrix(new double[][]{
- {1.0, 0.0, 0.0},
- {0.0, 2.0, 0.0},
- {0.0, 0.0, 3.0}
- }), DoubleMatrix.diag(new DoubleMatrix(3, 1, 1.0, 2.0, 3.0)));
- }
+ // in-place in this
+ R = B.dup();
+ assertEquals(result2, R.gei(4.0, R));
- public void testColumnAndRowMinMax() {
- assertEquals(new DoubleMatrix(1, 3, 1.0, 5.0, 9.0), A.columnMins());
- assertEquals(new DoubleMatrix(4, 1, 1.0, 2.0, 3.0, 4.0), A.rowMins());
- assertEquals(new DoubleMatrix(1, 3, 4.0, 8.0, 12.0), A.columnMaxs());
- assertEquals(new DoubleMatrix(4, 1, 9.0, 10.0, 11.0, 12.0), A.rowMaxs());
- int[] i = A.columnArgmins();
- assertEquals(0, i[0]);
- assertEquals(0, i[1]);
- assertEquals(0, i[2]);
- i = A.columnArgmaxs();
- assertEquals(3, i[0]);
- assertEquals(3, i[1]);
- assertEquals(3, i[2]);
- i = A.rowArgmins();
- assertEquals(0, i[0]);
- assertEquals(0, i[1]);
- assertEquals(0, i[2]);
- assertEquals(0, i[3]);
- i = A.rowArgmaxs();
- assertEquals(2, i[0]);
- assertEquals(2, i[1]);
- assertEquals(2, i[2]);
- assertEquals(2, i[3]);
- }
+ // fully dynamic
+ assertEquals(result2, B.ge(4.0));
+ }
+//RJPP-END--------------------------------------------------------------
+ @Test
+ public void testMinMax() {
+ assertEquals(1.0, A.min(), eps);
+ assertEquals(12.0, A.max(), eps);
+ }
- public void testToArray() {
- assertTrue(Arrays.equals(new double[]{2.0, 4.0, 8.0}, B.toArray()));
- assertTrue(Arrays.equals(new int[]{2, 4, 8}, B.toIntArray()));
- assertTrue(Arrays.equals(new boolean[]{true, true, true}, B.toBooleanArray()));
- }
+ @Test
+ public void testArgMinMax() {
+ assertEquals(0, A.argmin(), eps);
+ assertEquals(11, A.argmax(), eps);
+ }
- public void testLoadAsciiFile() {
- try {
- File f = File.createTempFile("jblas-test", "txt");
- f.deleteOnExit();
- PrintStream out = new PrintStream(f);
- out.println("1.0 2.0 3.0");
- out.println("4.0 5.0 6.0");
- out.close();
-
- DoubleMatrix result = DoubleMatrix.loadAsciiFile(f.getAbsolutePath());
- assertEquals(new DoubleMatrix(2, 3, 1.0, 4.0, 2.0, 5.0, 3.0, 6.0), result);
- } catch (Exception e) {
- fail("Caught exception " + e);
- }
- }
-
- public void testRanges() {
- // Hm... Broken?
- //System.out.printf("Ranges: %s\n", A.get(interval(0, 2), interval(0, 1)).toString());
- //assertEquals(new DoubleMatrix(3, 2, 1.0, 2.0, 3.0, 5.0, 6.0, 7.0), );
- }
+ @Test
+ public void testTranspose() {
+ DoubleMatrix At = A.transpose();
+ assertEquals(1.0, At.get(0, 0), eps);
+ assertEquals(2.0, At.get(0, 1), eps);
+ assertEquals(5.0, At.get(1, 0), eps);
+ }
+
+ @Test
+ public void testGetRowVector() {
+ for (int r = 0; r < A.rows; r++) {
+ A.getRow(r);
+ }
+
+ for (int c = 0; c < A.columns; c++) {
+ A.getColumn(c);
+ }
+
+ A.addiRowVector(new DoubleMatrix(3, 1, 10.0, 100.0, 1000.0));
+ A.addiColumnVector(new DoubleMatrix(1, 4, 10.0, 100.0, 1000.0, 10000.0));
+ }
+
+ @Test
+ public void testPairwiseDistance() {
+ DoubleMatrix D = Geometry.pairwiseSquaredDistances(A, A);
+
+ DoubleMatrix X = new DoubleMatrix(1, 3, 1.0, 0.0, -1.0);
+
+ Geometry.pairwiseSquaredDistances(X, X);
+
+ DoubleMatrix A1 = new DoubleMatrix(1, 2, 1.0, 2.0);
+ DoubleMatrix A2 = new DoubleMatrix(1, 3, 1.0, 2.0, 3.0);
+
+ Geometry.pairwiseSquaredDistances(A1, A2);
+ }
+
+ @Test
+ public void testSwapColumns() {
+ DoubleMatrix AA = A.dup();
+
+ AA.swapColumns(1, 2);
+ assertEquals(new DoubleMatrix(4, 3, 1.0, 2.0, 3.0, 4.0, 9.0, 10.0, 11.0, 12.0, 5.0, 6.0, 7.0, 8.0), AA);
+ }
+
+ @Test
+ public void testSwapRows() {
+ DoubleMatrix AA = A.dup();
+
+ AA.swapRows(1, 2);
+ assertEquals(new DoubleMatrix(4, 3, 1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0, 9.0, 11.0, 10.0, 12.0), AA);
+ }
+
+ @Test
+ public void testSolve() {
+ DoubleMatrix AA = new DoubleMatrix(3, 3, 3.0, 5.0, 6.0, 1.0, 0.0, 0.0, 2.0, 4.0, 0.0);
+ DoubleMatrix BB = new DoubleMatrix(3, 1, 1.0, 2.0, 3.0);
+
+ DoubleMatrix Adup = AA.dup();
+ DoubleMatrix Bdup = BB.dup();
+
+ DoubleMatrix X = Solve.solve(AA, BB);
+
+ assertEquals(Adup, AA);
+ assertEquals(Bdup, BB);
+ }
+
+ @Test
+ public void testConstructFromArray() {
+ double[][] data = {
+ {1.0, 2.0, 3.0},
+ {4.0, 5.0, 6.0},
+ {7.0, 8.0, 9.0}
+ };
+
+ DoubleMatrix A = new DoubleMatrix(data);
+
+ for (int r = 0; r < 3; r++) {
+ for (int c = 0; c < 3; c++) {
+ assertEquals(data[r][c], A.get(r, c), eps);
+ }
+ }
+ }
+
+ @Test
+ public void testDiag() {
+ DoubleMatrix A = new DoubleMatrix(new double[][]{
+ {1.0, 2.0, 3.0},
+ {4.0, 5.0, 6.0},
+ {7.0, 8.0, 9.0}
+ });
+
+ assertEquals(new DoubleMatrix(3, 1, 1.0, 5.0, 9.0), A.diag());
+
+ assertEquals(new DoubleMatrix(new double[][]{
+ {1.0, 0.0, 0.0},
+ {0.0, 2.0, 0.0},
+ {0.0, 0.0, 3.0}
+ }), DoubleMatrix.diag(new DoubleMatrix(3, 1, 1.0, 2.0, 3.0)));
+ }
+
+ @Test
+ public void testColumnAndRowMinMax() {
+ assertEquals(new DoubleMatrix(1, 3, 1.0, 5.0, 9.0), A.columnMins());
+ assertEquals(new DoubleMatrix(4, 1, 1.0, 2.0, 3.0, 4.0), A.rowMins());
+ assertEquals(new DoubleMatrix(1, 3, 4.0, 8.0, 12.0), A.columnMaxs());
+ assertEquals(new DoubleMatrix(4, 1, 9.0, 10.0, 11.0, 12.0), A.rowMaxs());
+ int[] i = A.columnArgmins();
+ assertEquals(0, i[0]);
+ assertEquals(0, i[1]);
+ assertEquals(0, i[2]);
+ i = A.columnArgmaxs();
+ assertEquals(3, i[0]);
+ assertEquals(3, i[1]);
+ assertEquals(3, i[2]);
+ i = A.rowArgmins();
+ assertEquals(0, i[0]);
+ assertEquals(0, i[1]);
+ assertEquals(0, i[2]);
+ assertEquals(0, i[3]);
+ i = A.rowArgmaxs();
+ assertEquals(2, i[0]);
+ assertEquals(2, i[1]);
+ assertEquals(2, i[2]);
+ assertEquals(2, i[3]);
+ }
+
+ @Test
+ public void testToArray() {
+ assertTrue(Arrays.equals(new double[]{2.0, 4.0, 8.0}, B.toArray()));
+ assertTrue(Arrays.equals(new int[]{2, 4, 8}, B.toIntArray()));
+ assertTrue(Arrays.equals(new boolean[]{true, true, true}, B.toBooleanArray()));
+ }
+
+ @Test
+ public void testLoadAsciiFile() {
+ try {
+ File f = File.createTempFile("jblas-test", "txt");
+ f.deleteOnExit();
+ PrintStream out = new PrintStream(f);
+ out.println("1.0 2.0 3.0");
+ out.println("4.0 5.0 6.0");
+ out.close();
+
+ DoubleMatrix result = DoubleMatrix.loadAsciiFile(f.getAbsolutePath());
+ assertEquals(new DoubleMatrix(2, 3, 1.0, 4.0, 2.0, 5.0, 3.0, 6.0), result);
+ } catch (Exception e) {
+ fail("Caught exception " + e);
+ }
+ }
+
+ @Test
+ public void testRanges() {
+ DoubleMatrix A = new DoubleMatrix(3, 3, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0);
+ DoubleMatrix B = new DoubleMatrix(2, 3, -1.0, -2.0, -3.0, -4.0, -5.0, -6.0);
+
+ A.put(interval(0, 2), interval(0, 3), B);
+
+ /*assertEquals(-1.0, A.get(0, 0));
+ assertEquals(-2.0, A.get(0, 1));
+ assertEquals(-3.0, A.get(0, 2));
+ assertEquals(-4.0, A.get(1, 0));
+ assertEquals(-5.0, A.get(1, 1));
+ assertEquals(-6.0, A.get(1, 2));*/
+ }
+
+ @Test
+ public void testRandWithSeed() {
+ Random.seed(1);
+ DoubleMatrix A = DoubleMatrix.rand(3, 3);
+ Random.seed(1);
+ DoubleMatrix B = DoubleMatrix.rand(3, 3);
+ assertEquals(0.0, A.sub(B).normmax(), 1e-9);
+ }
+
+ @Test
+ public void testToString() {
+ // We have to be a bit cautious here because my Double => Float converter scripts will
+ // add a "f" to every floating point number, even in the strings. Therefore, I
+ // explicitly remove all "f"s
+ assertEquals("[1.000000, 5.000000, 9.000000; 2.000000, 6.000000, 10.000000; 3.000000, 7.000000, 11.000000; 4.000000, 8.000000, 12.000000]".replaceAll("f", ""), A.toString());
+
+ assertEquals("[1.0, 5.0, 9.0; 2.0, 6.0, 10.0; 3.0, 7.0, 11.0; 4.0, 8.0, 12.0]".replaceAll("f", ""), A.toString("%.1f"));
+
+ assertEquals("{1.0 5.0 9.0; 2.0 6.0 10.0; 3.0 7.0 11.0; 4.0 8.0 12.0}".replaceAll("f", ""), A.toString("%.1f", "{", "}", " ", "; "));
+ }
}
diff --git a/src/test/java/org/jblas/TestEigen.java b/src/test/java/org/jblas/TestEigen.java
index 5aa7215..276d78e 100644
--- a/src/test/java/org/jblas/TestEigen.java
+++ b/src/test/java/org/jblas/TestEigen.java
@@ -42,57 +42,60 @@ package org.jblas;
import junit.framework.TestCase;
import org.jblas.util.Logger;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
/**
* Test Class for org.jblas.Eigen
- *
+ *
* @author mikio
*/
-public class TestEigen extends TestCase {
- public TestEigen(String testName) {
- super(testName);
- Logger.getLogger().setLevel(Logger.DEBUG);
- }
+public class TestEigen {
+
+ private final double eps = 1e-10;
- public void testEigenvalues() {
- DoubleMatrix A = new DoubleMatrix(2, 2, 3.0, -3.0, 1.0, 1.0);
+ @Test
+ public void testEigenvalues() {
+ DoubleMatrix A = new DoubleMatrix(2, 2, 3.0, -3.0, 1.0, 1.0);
- ComplexDoubleMatrix E = Eigen.eigenvalues(A);
+ ComplexDoubleMatrix E = Eigen.eigenvalues(A);
- //System.out.printf("E = %s\n", E.toString());
+ ComplexDoubleMatrix[] EV = Eigen.eigenvectors(A);
- ComplexDoubleMatrix[] EV = Eigen.eigenvectors(A);
+ ComplexDoubleMatrix X = EV[0];
+ ComplexDoubleMatrix L = EV[1];
- //System.out.printf("values = %s\n", EV[1].toString());
- //System.out.printf("vectors = %s\n", EV[0].toString());
- }
+ assertEquals(0.0, A.toComplex().mmul(X).sub(X.mmul(L)).norm2(), eps);
+ }
- public void testSymmetricEigenvalues() {
- DoubleMatrix A = new DoubleMatrix(new double[][]{
- {3.0, 1.0, 0.5},
- {1.0, 3.0, 1.0},
- {0.5, 1.0, 3.0}
- });
+ @Test
+ public void testSymmetricEigenvalues() {
+ DoubleMatrix A = new DoubleMatrix(new double[][]{
+ {3.0, 1.0, 0.5},
+ {1.0, 3.0, 1.0},
+ {0.5, 1.0, 3.0}
+ });
- DoubleMatrix B = new DoubleMatrix(new double[][]{
- {2.0, 0.1, 0.0},
- {0.1, 2.0, 0.1},
- {0.0, 0.1, 2.0}
- });
+ DoubleMatrix B = new DoubleMatrix(new double[][]{
+ {2.0, 0.1, 0.0},
+ {0.1, 2.0, 0.1},
+ {0.0, 0.1, 2.0}
+ });
- DoubleMatrix[] results = Eigen.symmetricGeneralizedEigenvectors(A, B);
+ DoubleMatrix[] results = Eigen.symmetricGeneralizedEigenvectors(A, B);
- DoubleMatrix V = results[0];
- DoubleMatrix L = results[1];
+ DoubleMatrix V = results[0];
+ DoubleMatrix L = results[1];
- DoubleMatrix LHS = A.mmul(V);
- DoubleMatrix RHS = B.mmul(V).mmul(DoubleMatrix.diag(L));
+ DoubleMatrix LHS = A.mmul(V);
+ DoubleMatrix RHS = B.mmul(V).mmul(DoubleMatrix.diag(L));
- assertEquals(0.0, LHS.sub(RHS).normmax(), 1e-3);
+ assertEquals(0.0, LHS.sub(RHS).normmax(), eps);
- DoubleMatrix eigenvalues = Eigen.symmetricGeneralizedEigenvalues(A, B);
+ DoubleMatrix eigenvalues = Eigen.symmetricGeneralizedEigenvalues(A, B);
- assertEquals(0.0, eigenvalues.sub(L).normmax(), 1e-3);
- }
+ assertEquals(0.0, eigenvalues.sub(L).normmax(), eps);
+ }
}
diff --git a/src/test/java/org/jblas/TestFloatMatrix.java b/src/test/java/org/jblas/TestFloatMatrix.java
index c669188..17b04b2 100644
--- a/src/test/java/org/jblas/TestFloatMatrix.java
+++ b/src/test/java/org/jblas/TestFloatMatrix.java
@@ -38,610 +38,690 @@ package org.jblas;
import java.io.File;
import java.io.PrintStream;
-import junit.framework.TestCase;
+
+import org.jblas.util.Random;
+
import java.util.Arrays;
+
import static org.jblas.ranges.RangeUtils.*;
-public class TestFloatMatrix extends TestCase {
+import org.junit.Before;
+import org.junit.Test;
- FloatMatrix A, B, C, D, E, F;
+import static org.junit.Assert.*;
- public void setUp() {
- A = new FloatMatrix(4, 3, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f);
- B = new FloatMatrix(3, 1, 2.0f, 4.0f, 8.0f);
- C = new FloatMatrix(3, 1, -1.0f, 2.0f, -3.0f);
- D = new FloatMatrix(3, 3, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f);
- E = new FloatMatrix(3, 3, 1.0f, -2.0f, 3.0f, -4.0f, 5.0f, -6.0f, 7.0f, -8.0f, 9.0f);
- F = new FloatMatrix(3, 1, 3.0f, 4.0f, 7.0f);
- }
+public class TestFloatMatrix {
- public void testConstructionAndSetGet() {
- float[][] dataA = {{1.0f, 5.0f, 9.0f}, {2.0f, 6.0f, 10.0f}, {3.0f, 7.0f, 11.0f}, {4.0f, 8.0f, 12.0f}};
+ private final float eps = 1e-6f;
- assertEquals(A.rows, 4);
- assertEquals(A.columns, 3);
+ FloatMatrix A, B, C, D, E, F;
- for (int r = 0; r < 4; r++) {
- for (int c = 0; c < 3; c++) {
- assertEquals(dataA[r][c], A.get(r, c));
- }
- }
- }
-
- public void testSetAndGet() {
- FloatMatrix M = new FloatMatrix(3, 3);
-
- for (int i = 0; i < 3; i++) {
- for (int j = 0; j < 3; j++) {
- M.put(i, j, i + j);
- }
- }
- for (int i = 0; i < 3; i++) {
- for (int j = 0; j < 3; j++) {
- assertEquals((float) i + j, M.get(i, j));
- }
- }
- }
+ @Before
+ public void setUp() {
+ A = new FloatMatrix(4, 3, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f);
+ B = new FloatMatrix(3, 1, 2.0f, 4.0f, 8.0f);
+ C = new FloatMatrix(3, 1, -1.0f, 2.0f, -3.0f);
+ D = new FloatMatrix(3, 3, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f);
+ E = new FloatMatrix(3, 3, 1.0f, -2.0f, 3.0f, -4.0f, 5.0f, -6.0f, 7.0f, -8.0f, 9.0f);
+ F = new FloatMatrix(3, 1, 3.0f, 4.0f, 7.0f);
+ }
- public void testCopy() {
- FloatMatrix M = new FloatMatrix();
+ @Test
+ public void testConstructionAndSetGet() {
+ float[][] dataA = {{1.0f, 5.0f, 9.0f}, {2.0f, 6.0f, 10.0f}, {3.0f, 7.0f, 11.0f}, {4.0f, 8.0f, 12.0f}};
- assertFalse(M.equals(A));
+ assertEquals(A.rows, 4);
+ assertEquals(A.columns, 3);
- M.copy(A);
- assertEquals(M, A);
+ for (int r = 0; r < 4; r++) {
+ for (int c = 0; c < 3; c++) {
+ assertEquals(dataA[r][c], A.get(r, c), eps);
+ }
}
+ }
- public void testDup() {
- FloatMatrix M = A.dup();
- assertEquals(M, A);
+ @Test
+ public void testSetAndGet() {
+ FloatMatrix M = new FloatMatrix(3, 3);
- M.put(0, 0, 2.0f);
- assertFalse(M.equals(A));
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 3; j++) {
+ M.put(i, j, i + j);
+ }
}
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 3; j++) {
+ assertEquals((float) i + j, M.get(i, j), eps);
+ }
+ }
+ }
- public void testResize() {
- FloatMatrix M = A.dup();
+ @Test
+ public void testGetWithRowIndicesAndSingleColumn() {
+ FloatMatrix M = new FloatMatrix(new float[][] {{1, 2}, {3, 4}});
- assertEquals(4, M.rows);
- assertEquals(3, M.columns);
+ assertEquals(new FloatMatrix(2, 1, 1, 3), M.get(new int[]{0, 1}, 0));
+ assertEquals(new FloatMatrix(2, 1, 2, 4), M.get(new int[]{0, 1}, 1));
+ }
- M.resize(4, 5);
- assertEquals(4, M.rows);
- assertEquals(5, M.columns);
+ @Test
+ public void testCopy() {
+ FloatMatrix M = new FloatMatrix();
- assertEquals(0.0f, M.get(3, 4));
- }
+ assertFalse(M.equals(A));
- public void testReshape() {
- FloatMatrix M = new FloatMatrix(2, 2, 1.0f, 2.0f, 3.0f, 4.0f);
+ M.copy(A);
+ assertEquals(M, A);
+ }
- M.reshape(1, 4);
- assertEquals(1.0f, M.get(0, 0));
- assertEquals(4.0f, M.get(0, 3));
+ @Test
+ public void testDup() {
+ FloatMatrix M = A.dup();
+ assertEquals(M, A);
- M.reshape(4, 1);
- assertEquals(1.0f, M.get(0, 0));
- assertEquals(4.0f, M.get(3, 0));
- }
+ M.put(0, 0, 2.0f);
+ assertFalse(M.equals(A));
+ }
- public void testMmul() {
- FloatMatrix R = A.dup();
- FloatMatrix result = new FloatMatrix(4, 1, 94.0f, 108.0f, 122.0f, 136.0f);
+ @Test
+ public void testResize() {
+ FloatMatrix M = A.dup();
- A.mmuli(B, R);
- assertEquals(result, R);
+ assertEquals(4, M.rows);
+ assertEquals(3, M.columns);
- assertEquals(result, A.mmul(B));
+ M.resize(4, 5);
+ assertEquals(4, M.rows);
+ assertEquals(5, M.columns);
- FloatMatrix resultDE = new FloatMatrix(3, 3, 14.0f, 16.0f, 18.0f, -26.0f, -31.0f, -36.0f, 38.0f, 46.0f, 54.0f);
+ assertEquals(0.0f, M.get(3, 4), eps);
+ }
- // In-place with independent operands
- assertEquals(resultDE, D.mmuli(E, R));
+ @Test
+ public void testReshape() {
+ FloatMatrix M = new FloatMatrix(2, 2, 1.0f, 2.0f, 3.0f, 4.0f);
- // In-place on this
- R = D.dup();
- assertEquals(resultDE, R.mmuli(E, R));
+ M.reshape(1, 4);
+ assertEquals(1.0f, M.get(0, 0), eps);
+ assertEquals(4.0f, M.get(0, 3), eps);
- // In-place on this
- R = E.dup();
- assertEquals(resultDE, D.mmuli(R, R));
+ M.reshape(4, 1);
+ assertEquals(1.0f, M.get(0, 0), eps);
+ assertEquals(4.0f, M.get(3, 0), eps);
+ }
- // Fully dynamic
- assertEquals(resultDE, D.mmul(E));
- }
+ @Test
+ public void testMmul() {
+ FloatMatrix R = A.dup();
+ FloatMatrix result = new FloatMatrix(4, 1, 94.0f, 108.0f, 122.0f, 136.0f);
- public void testAdd() {
- FloatMatrix result = new FloatMatrix(3, 1, 1.0f, 6.0f, 5.0f);
+ A.mmuli(B, R);
+ assertEquals(result, R);
- FloatMatrix R = new FloatMatrix();
+ assertEquals(result, A.mmul(B));
- // In-place, but independent operands
- B.addi(C, R);
- assertEquals(result, R);
+ FloatMatrix resultDE = new FloatMatrix(3, 3, 14.0f, 16.0f, 18.0f, -26.0f, -31.0f, -36.0f, 38.0f, 46.0f, 54.0f);
- // In-place on this
- R = B.dup();
- assertEquals(result, R.addi(C, R));
+ // In-place with independent operands
+ assertEquals(resultDE, D.mmuli(E, R));
- // In-place on other
- R = C.dup();
- assertEquals(result, B.addi(R, R));
+ // In-place on this
+ R = D.dup();
+ assertEquals(resultDE, R.mmuli(E, R));
- // fully dynamic
- assertEquals(result, B.add(C));
+ // In-place on this
+ R = E.dup();
+ assertEquals(resultDE, D.mmuli(R, R));
- result = new FloatMatrix(3, 1, 3.0f, 5.0f, 9.0f);
+ // Fully dynamic
+ assertEquals(resultDE, D.mmul(E));
+ }
- // In-place, but independent operands
- assertEquals(result, B.addi(1.0f, R));
+ @Test
+ public void testAdd() {
+ FloatMatrix result = new FloatMatrix(3, 1, 1.0f, 6.0f, 5.0f);
- // In-place on this
- R = B.dup();
- assertEquals(result, R.addi(1.0f, R));
+ FloatMatrix R = new FloatMatrix();
- // fully dynamic
- assertEquals(result, B.add(1.0f));
- }
+ // In-place, but independent operands
+ B.addi(C, R);
+ assertEquals(result, R);
- public void testSub() {
- FloatMatrix result = new FloatMatrix(3, 1, 3.0f, 2.0f, 11.0f);
+ // In-place on this
+ R = B.dup();
+ assertEquals(result, R.addi(C, R));
- FloatMatrix R = new FloatMatrix();
+ // In-place on other
+ R = C.dup();
+ assertEquals(result, B.addi(R, R));
- // In-place, but independent operands
- assertEquals(result, B.subi(C, R));
+ // fully dynamic
+ assertEquals(result, B.add(C));
- // In-place on this
- R = B.dup();
- assertEquals(result, R.subi(C, R));
+ result = new FloatMatrix(3, 1, 3.0f, 5.0f, 9.0f);
- // In-place on other
- R = C.dup();
- assertEquals(result, B.subi(R, R));
+ // In-place, but independent operands
+ assertEquals(result, B.addi(1.0f, R));
- // fully dynamic
- assertEquals(result, B.sub(C));
+ // In-place on this
+ R = B.dup();
+ assertEquals(result, R.addi(1.0f, R));
- result = new FloatMatrix(3, 1, 1.0f, 3.0f, 7.0f);
+ // fully dynamic
+ assertEquals(result, B.add(1.0f));
+ }
- // In-place, but independent operands
- assertEquals(result, B.subi(1.0f, R));
+ @Test
+ public void testSub() {
+ FloatMatrix result = new FloatMatrix(3, 1, 3.0f, 2.0f, 11.0f);
- // In-place on this
- R = B.dup();
- assertEquals(result, R.subi(1.0f, R));
+ FloatMatrix R = new FloatMatrix();
- // fully dynamic
- assertEquals(result, B.sub(1.0f));
- }
+ // In-place, but independent operands
+ assertEquals(result, B.subi(C, R));
- public void testRsub() {
- FloatMatrix result = new FloatMatrix(3, 1, 3.0f, 2.0f, 11.0f);
+ // In-place on this
+ R = B.dup();
+ assertEquals(result, R.subi(C, R));
- FloatMatrix R = new FloatMatrix();
+ // In-place on other
+ R = C.dup();
+ assertEquals(result, B.subi(R, R));
- // In-place, but independent operands
- assertEquals(result, C.rsubi(B, R));
+ // fully dynamic
+ assertEquals(result, B.sub(C));
- // In-place on this
- R = C.dup();
- assertEquals(result, R.rsubi(B, R));
+ result = new FloatMatrix(3, 1, 1.0f, 3.0f, 7.0f);
- // In-place on other
- R = B.dup();
- assertEquals(result, C.rsubi(R, R));
+ // In-place, but independent operands
+ assertEquals(result, B.subi(1.0f, R));
- // fully dynamic
- assertEquals(result, C.rsub(B));
+ // In-place on this
+ R = B.dup();
+ assertEquals(result, R.subi(1.0f, R));
- result = new FloatMatrix(3, 1, -1.0f, -3.0f, -7.0f);
+ // fully dynamic
+ assertEquals(result, B.sub(1.0f));
+ }
- // In-place, but independent operands
- assertEquals(result, B.rsubi(1.0f, R));
+ @Test
+ public void testRsub() {
+ FloatMatrix result = new FloatMatrix(3, 1, 3.0f, 2.0f, 11.0f);
- // In-place on this
- R = B.dup();
- assertEquals(result, R.rsubi(1.0f, R));
+ FloatMatrix R = new FloatMatrix();
- // fully dynamic
- assertEquals(result, B.rsub(1.0f));
- }
+ // In-place, but independent operands
+ assertEquals(result, C.rsubi(B, R));
- public void testMul() {
- FloatMatrix result = new FloatMatrix(3, 1, -2.0f, 8.0f, -24.0f);
+ // In-place on this
+ R = C.dup();
+ assertEquals(result, R.rsubi(B, R));
- FloatMatrix R = new FloatMatrix();
+ // In-place on other
+ R = B.dup();
+ assertEquals(result, C.rsubi(R, R));
- // In-place, but independent operands
- assertEquals(result, B.muli(C, R));
+ // fully dynamic
+ assertEquals(result, C.rsub(B));
- // In-place on this
- R = B.dup();
- assertEquals(result, R.muli(C, R));
+ result = new FloatMatrix(3, 1, -1.0f, -3.0f, -7.0f);
- // In-place on other
- R = C.dup();
- assertEquals(result, B.muli(R, R));
+ // In-place, but independent operands
+ assertEquals(result, B.rsubi(1.0f, R));
- // fully dynamic
- assertEquals(result, B.mul(C));
+ // In-place on this
+ R = B.dup();
+ assertEquals(result, R.rsubi(1.0f, R));
- result = new FloatMatrix(3, 1, 1.0f, 2.0f, 4.0f);
+ // fully dynamic
+ assertEquals(result, B.rsub(1.0f));
+ }
- // In-place, but independent operands
- assertEquals(result, B.muli(0.5f, R));
+ @Test
+ public void testMul() {
+ FloatMatrix result = new FloatMatrix(3, 1, -2.0f, 8.0f, -24.0f);
- // In-place on this
- R = B.dup();
- assertEquals(result, R.muli(0.5f, R));
+ FloatMatrix R = new FloatMatrix();
- // fully dynamic
- assertEquals(result, B.mul(0.5f));
- }
+ // In-place, but independent operands
+ assertEquals(result, B.muli(C, R));
- public void testDiv() {
- FloatMatrix result = new FloatMatrix(3, 1, -2.0f, 2.0f, -2.666666666f);
+ // In-place on this
+ R = B.dup();
+ assertEquals(result, R.muli(C, R));
- FloatMatrix R = new FloatMatrix();
+ // In-place on other
+ R = C.dup();
+ assertEquals(result, B.muli(R, R));
- // In-place, but independent operands
- assertEquals(result, B.divi(C, R));
+ // fully dynamic
+ assertEquals(result, B.mul(C));
- // In-place on this
- R = B.dup();
- assertEquals(result, R.divi(C, R));
+ result = new FloatMatrix(3, 1, 1.0f, 2.0f, 4.0f);
- // In-place on other
- R = C.dup();
- assertEquals(result, B.divi(R, R));
+ // In-place, but independent operands
+ assertEquals(result, B.muli(0.5f, R));
- // fully dynamic
- assertEquals(result, B.div(C));
+ // In-place on this
+ R = B.dup();
+ assertEquals(result, R.muli(0.5f, R));
- result = new FloatMatrix(3, 1, 1.0f, 2.0f, 4.0f);
+ // fully dynamic
+ assertEquals(result, B.mul(0.5f));
+ }
- // In-place, but independent operands
- assertEquals(result, B.divi(2.0f, R));
+ @Test
+ public void testDiv() {
+ FloatMatrix result = new FloatMatrix(3, 1, -2.0f, 2.0f, -2.666666666f);
- // In-place on this
- R = B.dup();
- assertEquals(result, R.divi(2.0f, R));
+ FloatMatrix R = new FloatMatrix();
- // fully dynamic
- assertEquals(result, B.div(2.0f));
- }
+ // In-place, but independent operands
+ assertEquals(result, B.divi(C, R));
- public void testRdiv() {
- FloatMatrix result = new FloatMatrix(3, 1, -2.0f, 2.0f, -2.666666666f);
+ // In-place on this
+ R = B.dup();
+ assertEquals(result, R.divi(C, R));
- FloatMatrix R = new FloatMatrix();
+ // In-place on other
+ R = C.dup();
+ assertEquals(result, B.divi(R, R));
- // In-place, but independent operands
- assertEquals(result, C.rdivi(B, R));
+ // fully dynamic
+ assertEquals(result, B.div(C));
- // In-place on this
- R = C.dup();
- assertEquals(result, R.rdivi(B, R));
+ result = new FloatMatrix(3, 1, 1.0f, 2.0f, 4.0f);
- // In-place on other
- R = B.dup();
- assertEquals(result, C.rdivi(R, R));
+ // In-place, but independent operands
+ assertEquals(result, B.divi(2.0f, R));
- // fully dynamic
- assertEquals(result, C.rdiv(B));
+ // In-place on this
+ R = B.dup();
+ assertEquals(result, R.divi(2.0f, R));
- result = new FloatMatrix(3, 1, 0.5f, 0.25f, 0.125f);
+ // fully dynamic
+ assertEquals(result, B.div(2.0f));
+ }
- // In-place, but independent operands
- assertEquals(result, B.rdivi(1.0f, R));
+ @Test
+ public void testRdiv() {
+ FloatMatrix result = new FloatMatrix(3, 1, -2.0f, 2.0f, -2.666666666f);
- // In-place on this
- R = B.dup();
- assertEquals(result, R.rdivi(1.0f, R));
+ FloatMatrix R = new FloatMatrix();
- // fully dynamic
- assertEquals(result, B.rdiv(1.0f));
- }
+ // In-place, but independent operands
+ assertEquals(result, C.rdivi(B, R));
- /*# def test_logical(op, result, scalar, result2); <<-EOS
- public void test#{op.upcase}() {
- FloatMatrix result = new FloatMatrix(3, 1, #{result});
- FloatMatrix result2 = new FloatMatrix(3, 1, #{result2});
- FloatMatrix R = new FloatMatrix();
-
- // in-place but independent operands
- assertEquals(result, B.#{op}i(F, R));
-
- // in-place but in other
- R = F.dup();
- assertEquals(result, B.#{op}i(R, R));
-
- // in-place in this
- R = B.dup();
- assertEquals(result, R.#{op}i(F, R));
-
- // fully dynamic
- assertEquals(result, B.#{op}(F));
-
- // in-place in this
- R = B.dup();
- assertEquals(result2, R.#{op}i(#{scalar}, R));
-
- // fully dynamic
- assertEquals(result2, B.#{op}(#{scalar}));
- }
- EOS
- end
- #*/
- /*# test_logical('lt', '1.0f, 0.0f, 0.0f', 4.0f, '1.0f, 0.0f, 0.0f') #*/
-//RJPP-BEGIN------------------------------------------------------------
- public void testLT() {
- FloatMatrix result = new FloatMatrix(3, 1, 1.0f, 0.0f, 0.0f);
- FloatMatrix result2 = new FloatMatrix(3, 1, 1.0f, 0.0f, 0.0f);
- FloatMatrix R = new FloatMatrix();
-
- // in-place but independent operands
- assertEquals(result, B.lti(F, R));
-
- // in-place but in other
- R = F.dup();
- assertEquals(result, B.lti(R, R));
-
- // in-place in this
- R = B.dup();
- assertEquals(result, R.lti(F, R));
-
- // fully dynamic
- assertEquals(result, B.lt(F));
-
- // in-place in this
- R = B.dup();
- assertEquals(result2, R.lti(4.0f, R));
-
- // fully dynamic
- assertEquals(result2, B.lt(4.0f));
- }
-//RJPP-END--------------------------------------------------------------
- /*# test_logical('le', '1.0f, 1.0f, 0.0f', 4.0f, '1.0f, 1.0f, 0.0f') #*/
-//RJPP-BEGIN------------------------------------------------------------
- public void testLE() {
- FloatMatrix result = new FloatMatrix(3, 1, 1.0f, 1.0f, 0.0f);
- FloatMatrix result2 = new FloatMatrix(3, 1, 1.0f, 1.0f, 0.0f);
- FloatMatrix R = new FloatMatrix();
-
- // in-place but independent operands
- assertEquals(result, B.lei(F, R));
-
- // in-place but in other
- R = F.dup();
- assertEquals(result, B.lei(R, R));
-
- // in-place in this
- R = B.dup();
- assertEquals(result, R.lei(F, R));
-
- // fully dynamic
- assertEquals(result, B.le(F));
-
- // in-place in this
- R = B.dup();
- assertEquals(result2, R.lei(4.0f, R));
-
- // fully dynamic
- assertEquals(result2, B.le(4.0f));
- }
-//RJPP-END--------------------------------------------------------------
- /*# test_logical('gt', '0.0f, 0.0f, 1.0f', 4.0f, '0.0f, 0.0f, 1.0f') #*/
-//RJPP-BEGIN------------------------------------------------------------
- public void testGT() {
- FloatMatrix result = new FloatMatrix(3, 1, 0.0f, 0.0f, 1.0f);
- FloatMatrix result2 = new FloatMatrix(3, 1, 0.0f, 0.0f, 1.0f);
- FloatMatrix R = new FloatMatrix();
-
- // in-place but independent operands
- assertEquals(result, B.gti(F, R));
-
- // in-place but in other
- R = F.dup();
- assertEquals(result, B.gti(R, R));
-
- // in-place in this
+ // In-place on this
+ R = C.dup();
+ assertEquals(result, R.rdivi(B, R));
+
+ // In-place on other
R = B.dup();
- assertEquals(result, R.gti(F, R));
-
+ assertEquals(result, C.rdivi(R, R));
+
// fully dynamic
- assertEquals(result, B.gt(F));
-
- // in-place in this
+ assertEquals(result, C.rdiv(B));
+
+ result = new FloatMatrix(3, 1, 0.5f, 0.25f, 0.125f);
+
+ // In-place, but independent operands
+ assertEquals(result, B.rdivi(1.0f, R));
+
+ // In-place on this
R = B.dup();
- assertEquals(result2, R.gti(4.0f, R));
-
+ assertEquals(result, R.rdivi(1.0f, R));
+
// fully dynamic
- assertEquals(result2, B.gt(4.0f));
- }
-//RJPP-END--------------------------------------------------------------
- /*# test_logical('ge', '0.0f, 1.0f, 1.0f', 4.0f, '0.0f, 1.0f, 1.0f') #*/
+ assertEquals(result, B.rdiv(1.0f));
+ }
+
+ /*# def test_logical(op, result, scalar, result2); <<-EOS
+ @Test
+ public void test#{op.upcase}() {
+ FloatMatrix result = new FloatMatrix(3, 1, #{result});
+ FloatMatrix result2 = new FloatMatrix(3, 1, #{result2});
+ FloatMatrix R = new FloatMatrix();
+
+ // in-place but independent operands
+ assertEquals(result, B.#{op}i(F, R));
+
+ // in-place but in other
+ R = F.dup();
+ assertEquals(result, B.#{op}i(R, R));
+
+ // in-place in this
+ R = B.dup();
+ assertEquals(result, R.#{op}i(F, R));
+
+ // fully dynamic
+ assertEquals(result, B.#{op}(F));
+
+ // in-place in this
+ R = B.dup();
+ assertEquals(result2, R.#{op}i(#{scalar}, R));
+
+ // fully dynamic
+ assertEquals(result2, B.#{op}(#{scalar}));
+ }
+ EOS
+ end
+ #*/
+ /*# test_logical('lt', '1.0f, 0.0f, 0.0f', 4.0f, '1.0f, 0.0f, 0.0f') #*/
//RJPP-BEGIN------------------------------------------------------------
- public void testGE() {
- FloatMatrix result = new FloatMatrix(3, 1, 0.0f, 1.0f, 1.0f);
- FloatMatrix result2 = new FloatMatrix(3, 1, 0.0f, 1.0f, 1.0f);
- FloatMatrix R = new FloatMatrix();
-
- // in-place but independent operands
- assertEquals(result, B.gei(F, R));
-
- // in-place but in other
- R = F.dup();
- assertEquals(result, B.gei(R, R));
-
- // in-place in this
- R = B.dup();
- assertEquals(result, R.gei(F, R));
-
- // fully dynamic
- assertEquals(result, B.ge(F));
-
- // in-place in this
- R = B.dup();
- assertEquals(result2, R.gei(4.0f, R));
-
- // fully dynamic
- assertEquals(result2, B.ge(4.0f));
- }
-//RJPP-END--------------------------------------------------------------
- public void testMinMax() {
- assertEquals(1.0f, A.min());
- assertEquals(12.0f, A.max());
- }
+ @Test
+ public void testLT() {
+ FloatMatrix result = new FloatMatrix(3, 1, 1.0f, 0.0f, 0.0f);
+ FloatMatrix result2 = new FloatMatrix(3, 1, 1.0f, 0.0f, 0.0f);
+ FloatMatrix R = new FloatMatrix();
- public void testArgMinMax() {
- assertEquals(0, A.argmin());
- assertEquals(11, A.argmax());
- }
+ // in-place but independent operands
+ assertEquals(result, B.lti(F, R));
- public void testTranspose() {
- FloatMatrix At = A.transpose();
- assertEquals(1.0f, At.get(0, 0));
- assertEquals(2.0f, At.get(0, 1));
- assertEquals(5.0f, At.get(1, 0));
- }
+ // in-place but in other
+ R = F.dup();
+ assertEquals(result, B.lti(R, R));
- public void testGetRowVector() {
- for (int r = 0; r < A.rows; r++) {
- A.getRow(r);
- }
+ // in-place in this
+ R = B.dup();
+ assertEquals(result, R.lti(F, R));
- for (int c = 0; c < A.columns; c++) {
- A.getColumn(c);
- }
+ // fully dynamic
+ assertEquals(result, B.lt(F));
- A.addiRowVector(new FloatMatrix(3, 1, 10.0f, 100.0f, 1000.0f));
- A.addiColumnVector(new FloatMatrix(1, 4, 10.0f, 100.0f, 1000.0f, 10000.0f));
- }
-
- public void testPairwiseDistance() {
- FloatMatrix D = Geometry.pairwiseSquaredDistances(A, A);
+ // in-place in this
+ R = B.dup();
+ assertEquals(result2, R.lti(4.0f, R));
- FloatMatrix X = new FloatMatrix(1, 3, 1.0f, 0.0f, -1.0f);
+ // fully dynamic
+ assertEquals(result2, B.lt(4.0f));
+ }
+//RJPP-END--------------------------------------------------------------
+ /*# test_logical('le', '1.0f, 1.0f, 0.0f', 4.0f, '1.0f, 1.0f, 0.0f') #*/
+//RJPP-BEGIN------------------------------------------------------------
+ @Test
+ public void testLE() {
+ FloatMatrix result = new FloatMatrix(3, 1, 1.0f, 1.0f, 0.0f);
+ FloatMatrix result2 = new FloatMatrix(3, 1, 1.0f, 1.0f, 0.0f);
+ FloatMatrix R = new FloatMatrix();
- Geometry.pairwiseSquaredDistances(X, X);
+ // in-place but independent operands
+ assertEquals(result, B.lei(F, R));
- FloatMatrix A1 = new FloatMatrix(1, 2, 1.0f, 2.0f);
- FloatMatrix A2 = new FloatMatrix(1, 3, 1.0f, 2.0f, 3.0f);
+ // in-place but in other
+ R = F.dup();
+ assertEquals(result, B.lei(R, R));
- Geometry.pairwiseSquaredDistances(A1, A2);
- }
+ // in-place in this
+ R = B.dup();
+ assertEquals(result, R.lei(F, R));
- public void testSwapColumns() {
- FloatMatrix AA = A.dup();
+ // fully dynamic
+ assertEquals(result, B.le(F));
- AA.swapColumns(1, 2);
- assertEquals(new FloatMatrix(4, 3, 1.0f, 2.0f, 3.0f, 4.0f, 9.0f, 10.0f, 11.0f, 12.0f, 5.0f, 6.0f, 7.0f, 8.0f), AA);
- }
+ // in-place in this
+ R = B.dup();
+ assertEquals(result2, R.lei(4.0f, R));
- public void testSwapRows() {
- FloatMatrix AA = A.dup();
+ // fully dynamic
+ assertEquals(result2, B.le(4.0f));
+ }
+//RJPP-END--------------------------------------------------------------
+ /*# test_logical('gt', '0.0f, 0.0f, 1.0f', 4.0f, '0.0f, 0.0f, 1.0f') #*/
+//RJPP-BEGIN------------------------------------------------------------
+ @Test
+ public void testGT() {
+ FloatMatrix result = new FloatMatrix(3, 1, 0.0f, 0.0f, 1.0f);
+ FloatMatrix result2 = new FloatMatrix(3, 1, 0.0f, 0.0f, 1.0f);
+ FloatMatrix R = new FloatMatrix();
- AA.swapRows(1, 2);
- assertEquals(new FloatMatrix(4, 3, 1.0f, 3.0f, 2.0f, 4.0f, 5.0f, 7.0f, 6.0f, 8.0f, 9.0f, 11.0f, 10.0f, 12.0f), AA);
- }
+ // in-place but independent operands
+ assertEquals(result, B.gti(F, R));
- public void testSolve() {
- FloatMatrix AA = new FloatMatrix(3, 3, 3.0f, 5.0f, 6.0f, 1.0f, 0.0f, 0.0f, 2.0f, 4.0f, 0.0f);
- FloatMatrix BB = new FloatMatrix(3, 1, 1.0f, 2.0f, 3.0f);
+ // in-place but in other
+ R = F.dup();
+ assertEquals(result, B.gti(R, R));
- FloatMatrix Adup = AA.dup();
- FloatMatrix Bdup = BB.dup();
+ // in-place in this
+ R = B.dup();
+ assertEquals(result, R.gti(F, R));
- FloatMatrix X = Solve.solve(AA, BB);
+ // fully dynamic
+ assertEquals(result, B.gt(F));
- assertEquals(Adup, AA);
- assertEquals(Bdup, BB);
- }
+ // in-place in this
+ R = B.dup();
+ assertEquals(result2, R.gti(4.0f, R));
- public void testConstructFromArray() {
- float[][] data = {
- {1.0f, 2.0f, 3.0f},
- {4.0f, 5.0f, 6.0f},
- {7.0f, 8.0f, 9.0f}
- };
+ // fully dynamic
+ assertEquals(result2, B.gt(4.0f));
+ }
+//RJPP-END--------------------------------------------------------------
+ /*# test_logical('ge', '0.0f, 1.0f, 1.0f', 4.0f, '0.0f, 1.0f, 1.0f') #*/
+//RJPP-BEGIN------------------------------------------------------------
+ @Test
+ public void testGE() {
+ FloatMatrix result = new FloatMatrix(3, 1, 0.0f, 1.0f, 1.0f);
+ FloatMatrix result2 = new FloatMatrix(3, 1, 0.0f, 1.0f, 1.0f);
+ FloatMatrix R = new FloatMatrix();
- FloatMatrix A = new FloatMatrix(data);
+ // in-place but independent operands
+ assertEquals(result, B.gei(F, R));
- for (int r = 0; r < 3; r++) {
- for (int c = 0; c < 3; c++) {
- assertEquals(data[r][c], A.get(r, c));
- }
- }
- }
+ // in-place but in other
+ R = F.dup();
+ assertEquals(result, B.gei(R, R));
- public void testDiag() {
- FloatMatrix A = new FloatMatrix(new float[][]{
- {1.0f, 2.0f, 3.0f},
- {4.0f, 5.0f, 6.0f},
- {7.0f, 8.0f, 9.0f}
- });
+ // in-place in this
+ R = B.dup();
+ assertEquals(result, R.gei(F, R));
- assertEquals(new FloatMatrix(3, 1, 1.0f, 5.0f, 9.0f), A.diag());
+ // fully dynamic
+ assertEquals(result, B.ge(F));
- assertEquals(new FloatMatrix(new float[][]{
- {1.0f, 0.0f, 0.0f},
- {0.0f, 2.0f, 0.0f},
- {0.0f, 0.0f, 3.0f}
- }), FloatMatrix.diag(new FloatMatrix(3, 1, 1.0f, 2.0f, 3.0f)));
- }
+ // in-place in this
+ R = B.dup();
+ assertEquals(result2, R.gei(4.0f, R));
- public void testColumnAndRowMinMax() {
- assertEquals(new FloatMatrix(1, 3, 1.0f, 5.0f, 9.0f), A.columnMins());
- assertEquals(new FloatMatrix(4, 1, 1.0f, 2.0f, 3.0f, 4.0f), A.rowMins());
- assertEquals(new FloatMatrix(1, 3, 4.0f, 8.0f, 12.0f), A.columnMaxs());
- assertEquals(new FloatMatrix(4, 1, 9.0f, 10.0f, 11.0f, 12.0f), A.rowMaxs());
- int[] i = A.columnArgmins();
- assertEquals(0, i[0]);
- assertEquals(0, i[1]);
- assertEquals(0, i[2]);
- i = A.columnArgmaxs();
- assertEquals(3, i[0]);
- assertEquals(3, i[1]);
- assertEquals(3, i[2]);
- i = A.rowArgmins();
- assertEquals(0, i[0]);
- assertEquals(0, i[1]);
- assertEquals(0, i[2]);
- assertEquals(0, i[3]);
- i = A.rowArgmaxs();
- assertEquals(2, i[0]);
- assertEquals(2, i[1]);
- assertEquals(2, i[2]);
- assertEquals(2, i[3]);
- }
+ // fully dynamic
+ assertEquals(result2, B.ge(4.0f));
+ }
+//RJPP-END--------------------------------------------------------------
+ @Test
+ public void testMinMax() {
+ assertEquals(1.0f, A.min(), eps);
+ assertEquals(12.0f, A.max(), eps);
+ }
- public void testToArray() {
- assertTrue(Arrays.equals(new float[]{2.0f, 4.0f, 8.0f}, B.toArray()));
- assertTrue(Arrays.equals(new int[]{2, 4, 8}, B.toIntArray()));
- assertTrue(Arrays.equals(new boolean[]{true, true, true}, B.toBooleanArray()));
- }
+ @Test
+ public void testArgMinMax() {
+ assertEquals(0, A.argmin(), eps);
+ assertEquals(11, A.argmax(), eps);
+ }
- public void testLoadAsciiFile() {
- try {
- File f = File.createTempFile("jblas-test", "txt");
- f.deleteOnExit();
- PrintStream out = new PrintStream(f);
- out.println("1.0f 2.0f 3.0f");
- out.println("4.0f 5.0f 6.0f");
- out.close();
-
- FloatMatrix result = FloatMatrix.loadAsciiFile(f.getAbsolutePath());
- assertEquals(new FloatMatrix(2, 3, 1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f), result);
- } catch (Exception e) {
- fail("Caught exception " + e);
- }
- }
-
- public void testRanges() {
- // Hm... Broken?
- //System.out.printf("Ranges: %s\n", A.get(interval(0, 2), interval(0, 1)).toString());
- //assertEquals(new FloatMatrix(3, 2, 1.0f, 2.0f, 3.0f, 5.0f, 6.0f, 7.0f), );
- }
-}
+ @Test
+ public void testTranspose() {
+ FloatMatrix At = A.transpose();
+ assertEquals(1.0f, At.get(0, 0), eps);
+ assertEquals(2.0f, At.get(0, 1), eps);
+ assertEquals(5.0f, At.get(1, 0), eps);
+ }
+
+ @Test
+ public void testGetRowVector() {
+ for (int r = 0; r < A.rows; r++) {
+ A.getRow(r);
+ }
+
+ for (int c = 0; c < A.columns; c++) {
+ A.getColumn(c);
+ }
+
+ A.addiRowVector(new FloatMatrix(3, 1, 10.0f, 100.0f, 1000.0f));
+ A.addiColumnVector(new FloatMatrix(1, 4, 10.0f, 100.0f, 1000.0f, 10000.0f));
+ }
+
+ @Test
+ public void testPairwiseDistance() {
+ FloatMatrix D = Geometry.pairwiseSquaredDistances(A, A);
+
+ FloatMatrix X = new FloatMatrix(1, 3, 1.0f, 0.0f, -1.0f);
+
+ Geometry.pairwiseSquaredDistances(X, X);
+
+ FloatMatrix A1 = new FloatMatrix(1, 2, 1.0f, 2.0f);
+ FloatMatrix A2 = new FloatMatrix(1, 3, 1.0f, 2.0f, 3.0f);
+
+ Geometry.pairwiseSquaredDistances(A1, A2);
+ }
+
+ @Test
+ public void testSwapColumns() {
+ FloatMatrix AA = A.dup();
+
+ AA.swapColumns(1, 2);
+ assertEquals(new FloatMatrix(4, 3, 1.0f, 2.0f, 3.0f, 4.0f, 9.0f, 10.0f, 11.0f, 12.0f, 5.0f, 6.0f, 7.0f, 8.0f), AA);
+ }
+
+ @Test
+ public void testSwapRows() {
+ FloatMatrix AA = A.dup();
+
+ AA.swapRows(1, 2);
+ assertEquals(new FloatMatrix(4, 3, 1.0f, 3.0f, 2.0f, 4.0f, 5.0f, 7.0f, 6.0f, 8.0f, 9.0f, 11.0f, 10.0f, 12.0f), AA);
+ }
+
+ @Test
+ public void testSolve() {
+ FloatMatrix AA = new FloatMatrix(3, 3, 3.0f, 5.0f, 6.0f, 1.0f, 0.0f, 0.0f, 2.0f, 4.0f, 0.0f);
+ FloatMatrix BB = new FloatMatrix(3, 1, 1.0f, 2.0f, 3.0f);
+
+ FloatMatrix Adup = AA.dup();
+ FloatMatrix Bdup = BB.dup();
+
+ FloatMatrix X = Solve.solve(AA, BB);
+
+ assertEquals(Adup, AA);
+ assertEquals(Bdup, BB);
+ }
+
+ @Test
+ public void testConstructFromArray() {
+ float[][] data = {
+ {1.0f, 2.0f, 3.0f},
+ {4.0f, 5.0f, 6.0f},
+ {7.0f, 8.0f, 9.0f}
+ };
+
+ FloatMatrix A = new FloatMatrix(data);
+
+ for (int r = 0; r < 3; r++) {
+ for (int c = 0; c < 3; c++) {
+ assertEquals(data[r][c], A.get(r, c), eps);
+ }
+ }
+ }
+
+ @Test
+ public void testDiag() {
+ FloatMatrix A = new FloatMatrix(new float[][]{
+ {1.0f, 2.0f, 3.0f},
+ {4.0f, 5.0f, 6.0f},
+ {7.0f, 8.0f, 9.0f}
+ });
+
+ assertEquals(new FloatMatrix(3, 1, 1.0f, 5.0f, 9.0f), A.diag());
+
+ assertEquals(new FloatMatrix(new float[][]{
+ {1.0f, 0.0f, 0.0f},
+ {0.0f, 2.0f, 0.0f},
+ {0.0f, 0.0f, 3.0f}
+ }), FloatMatrix.diag(new FloatMatrix(3, 1, 1.0f, 2.0f, 3.0f)));
+ }
+
+ @Test
+ public void testColumnAndRowMinMax() {
+ assertEquals(new FloatMatrix(1, 3, 1.0f, 5.0f, 9.0f), A.columnMins());
+ assertEquals(new FloatMatrix(4, 1, 1.0f, 2.0f, 3.0f, 4.0f), A.rowMins());
+ assertEquals(new FloatMatrix(1, 3, 4.0f, 8.0f, 12.0f), A.columnMaxs());
+ assertEquals(new FloatMatrix(4, 1, 9.0f, 10.0f, 11.0f, 12.0f), A.rowMaxs());
+ int[] i = A.columnArgmins();
+ assertEquals(0, i[0]);
+ assertEquals(0, i[1]);
+ assertEquals(0, i[2]);
+ i = A.columnArgmaxs();
+ assertEquals(3, i[0]);
+ assertEquals(3, i[1]);
+ assertEquals(3, i[2]);
+ i = A.rowArgmins();
+ assertEquals(0, i[0]);
+ assertEquals(0, i[1]);
+ assertEquals(0, i[2]);
+ assertEquals(0, i[3]);
+ i = A.rowArgmaxs();
+ assertEquals(2, i[0]);
+ assertEquals(2, i[1]);
+ assertEquals(2, i[2]);
+ assertEquals(2, i[3]);
+ }
+
+ @Test
+ public void testToArray() {
+ assertTrue(Arrays.equals(new float[]{2.0f, 4.0f, 8.0f}, B.toArray()));
+ assertTrue(Arrays.equals(new int[]{2, 4, 8}, B.toIntArray()));
+ assertTrue(Arrays.equals(new boolean[]{true, true, true}, B.toBooleanArray()));
+ }
+
+ @Test
+ public void testLoadAsciiFile() {
+ try {
+ File f = File.createTempFile("jblas-test", "txt");
+ f.deleteOnExit();
+ PrintStream out = new PrintStream(f);
+ out.println("1.0f 2.0f 3.0f");
+ out.println("4.0f 5.0f 6.0f");
+ out.close();
+
+ FloatMatrix result = FloatMatrix.loadAsciiFile(f.getAbsolutePath());
+ assertEquals(new FloatMatrix(2, 3, 1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f), result);
+ } catch (Exception e) {
+ fail("Caught exception " + e);
+ }
+ }
+
+ @Test
+ public void testRanges() {
+ FloatMatrix A = new FloatMatrix(3, 3, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f);
+ FloatMatrix B = new FloatMatrix(2, 3, -1.0f, -2.0f, -3.0f, -4.0f, -5.0f, -6.0f);
+
+ A.put(interval(0, 2), interval(0, 3), B);
+
+ /*assertEquals(-1.0f, A.get(0, 0));
+ assertEquals(-2.0f, A.get(0, 1));
+ assertEquals(-3.0f, A.get(0, 2));
+ assertEquals(-4.0f, A.get(1, 0));
+ assertEquals(-5.0f, A.get(1, 1));
+ assertEquals(-6.0f, A.get(1, 2));*/
+ }
+
+ @Test
+ public void testRandWithSeed() {
+ Random.seed(1);
+ FloatMatrix A = FloatMatrix.rand(3, 3);
+ Random.seed(1);
+ FloatMatrix B = FloatMatrix.rand(3, 3);
+ assertEquals(0.0f, A.sub(B).normmax(), 1e-9);
+ }
+
+ @Test
+ public void testToString() {
+ // We have to be a bit cautious here because my Float => Float converter scripts will
+ // add a "f" to every floating point number, even in the strings. Therefore, I
+ // explicitly remove all "f"s
+ assertEquals("[1.000000f, 5.000000f, 9.000000f; 2.000000f, 6.000000f, 10.000000f; 3.000000f, 7.000000f, 11.000000f; 4.000000f, 8.000000f, 12.000000f]".replaceAll("f", ""), A.toString());
+
+ assertEquals("[1.0f, 5.0f, 9.0f; 2.0f, 6.0f, 10.0f; 3.0f, 7.0f, 11.0f; 4.0f, 8.0f, 12.0f]".replaceAll("f", ""), A.toString("%.1f"));
+
+ assertEquals("{1.0f 5.0f 9.0f; 2.0f 6.0f 10.0f; 3.0f 7.0f 11.0f; 4.0f 8.0f 12.0f}".replaceAll("f", ""), A.toString("%.1f", "{", "}", " ", "; "));
+ }
+}
\ No newline at end of file
diff --git a/src/test/java/org/jblas/TestGeometry.java b/src/test/java/org/jblas/TestGeometry.java
index 557b590..fb19356 100644
--- a/src/test/java/org/jblas/TestGeometry.java
+++ b/src/test/java/org/jblas/TestGeometry.java
@@ -36,41 +36,49 @@
package org.jblas;
-import junit.framework.TestCase;
-
-public class TestGeometry extends TestCase {
- public void testCenter() {
- DoubleMatrix x = new DoubleMatrix(3, 1, 1.0, 2.0, 3.0);
-
- Geometry.center(x);
-
- assertEquals(new DoubleMatrix(3, 1, -1.0, 0.0, 1.0), x);
-
- DoubleMatrix M = new DoubleMatrix(new double[][] {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}});
-
- //M.print();
-
- DoubleMatrix MR = Geometry.centerRows(M.dup());
- DoubleMatrix MC = Geometry.centerColumns(M.dup());
-
- //MR.print();
- //MC.print();
-
- assertEquals(new DoubleMatrix(new double[][] {{-1.0, 0.0, 1.0}, {-1.0, 0.0, 1.0}, {-1.0, 0.0, 1.0}}), MR);
- assertEquals(new DoubleMatrix(new double[][] {{-3.0, -3.0, -3.0}, {0.0, 0.0, 0.0}, {3.0, 3.0, 3.0}}), MC);
- }
-
- public void testPwDist() {
- DoubleMatrix M = new DoubleMatrix(3, 5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0);
-
- DoubleMatrix D = Geometry.pairwiseSquaredDistances(M, M);
-
- D.print();
-
- M = M.transpose();
-
- D = Geometry.pairwiseSquaredDistances(M, M);
-
- D.print();
- }
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+public class TestGeometry {
+
+ /** Tolerance used when comparing computed distance matrices with their expected values. */
+ private static final double TOLERANCE = 1e-10;
+
+ @Test
+ public void testCenter() {
+ // Geometry.center modifies its argument in place.
+ DoubleMatrix v = new DoubleMatrix(3, 1, 1.0, 2.0, 3.0);
+ Geometry.center(v);
+ assertEquals(new DoubleMatrix(3, 1, -1.0, 0.0, 1.0), v);
+
+ DoubleMatrix m = new DoubleMatrix(new double[][]{{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}});
+
+ // Rows and columns are centered on separate copies so both start from the same data.
+ DoubleMatrix rowsCentered = Geometry.centerRows(m.dup());
+ DoubleMatrix columnsCentered = Geometry.centerColumns(m.dup());
+
+ // Each row loses its mean (2, 5, 8); each column loses its mean (4, 5, 6).
+ assertEquals(new DoubleMatrix(new double[][]{{-1.0, 0.0, 1.0}, {-1.0, 0.0, 1.0}, {-1.0, 0.0, 1.0}}), rowsCentered);
+ assertEquals(new DoubleMatrix(new double[][]{{-3.0, -3.0, -3.0}, {0.0, 0.0, 0.0}, {3.0, 3.0, 3.0}}), columnsCentered);
+ }
+
+ @Test
+ public void testPwDist() {
+ // Column j of m is (3j + 1, 3j + 2, 3j + 3), so columns i and j are 27 * (i - j)^2 apart.
+ DoubleMatrix m = new DoubleMatrix(3, 5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0);
+
+ DoubleMatrix expectedColumns = new DoubleMatrix(5, 5,
+ 0.0, 27.0, 108.0, 243.0, 432.0,
+ 27.0, 0.0, 27.0, 108.0, 243.0,
+ 108.0, 27.0, 0.0, 27.0, 108.0,
+ 243.0, 108.0, 27.0, 0.0, 27.0,
+ 432.0, 243.0, 108.0, 27.0, 0.0);
+ assertEquals(0.0, expectedColumns.distance2(Geometry.pairwiseSquaredDistances(m, m)), TOLERANCE);
+
+ // After transposing, the former rows are the columns; adjacent ones differ by 1 in all 5 entries.
+ DoubleMatrix t = m.transpose();
+ DoubleMatrix expectedRows = new DoubleMatrix(3, 3, 0.0, 5.0, 20.0, 5.0, 0.0, 5.0, 20.0, 5.0, 0.0);
+ assertEquals(0.0, expectedRows.distance2(Geometry.pairwiseSquaredDistances(t, t)), TOLERANCE);
+ }
}
diff --git a/src/test/java/org/jblas/TestSingular.java b/src/test/java/org/jblas/TestSingular.java
new file mode 100644
index 0000000..9636ddc
--- /dev/null
+++ b/src/test/java/org/jblas/TestSingular.java
@@ -0,0 +1,81 @@
+package org.jblas;
+
+import org.junit.*;
+import static org.junit.Assert.*;
+
+/**
+ * Test cases for class Singular
+ *
+ * Singular value decompositions
+ *
+ * @author Mikio L. Braun
+ */
+public class TestSingular {
+ @Test
+ public void testComplexDoubleSVD() {
+ ComplexDoubleMatrix A = new ComplexDoubleMatrix(3, 4);
+ for (int row = 0; row < A.rows; row++) {
+ for (int col = 0; col < A.columns; col++) {
+ A.put(row, col, (double) row, (double) col);
+ }
+ }
+ // sparseSVD yields the reduced factors: U is 3x3 and V is 4x3.
+ ComplexDoubleMatrix[] reduced = Singular.sparseSVD(A);
+ ComplexDoubleMatrix reducedS = ComplexDoubleMatrix.diag(reduced[1]);
+ assertShape(reduced[0], 3, 3);
+ assertShape(reduced[2], 4, 3);
+ assertReconstructs(A, reduced[0], reducedS, reduced[2], 1e-10);
+ // fullSVD yields square U (3x3) and V (4x4); S is padded to the 3x4 shape of A.
+ ComplexDoubleMatrix[] full = Singular.fullSVD(A);
+ ComplexDoubleMatrix fullS = ComplexDoubleMatrix.diag(full[1], 3, 4);
+ assertShape(full[0], 3, 3);
+ assertShape(fullS, 3, 4);
+ assertShape(full[2], 4, 4);
+ assertReconstructs(A, full[0], fullS, full[2], 1e-10);
+ }
+
+ @Test
+ public void testComplexFloatSVD() {
+ ComplexFloatMatrix A = new ComplexFloatMatrix(3, 4);
+ for (int row = 0; row < A.rows; row++) {
+ for (int col = 0; col < A.columns; col++) {
+ A.put(row, col, (float) row, (float) col);
+ }
+ }
+ // sparseSVD yields the reduced factors: U is 3x3 and V is 4x3.
+ ComplexFloatMatrix[] reduced = Singular.sparseSVD(A);
+ ComplexFloatMatrix reducedS = ComplexFloatMatrix.diag(reduced[1]);
+ assertShape(reduced[0], 3, 3);
+ assertShape(reduced[2], 4, 3);
+ assertReconstructs(A, reduced[0], reducedS, reduced[2], 1e-4);
+ // fullSVD yields square U (3x3) and V (4x4); S is padded to the 3x4 shape of A.
+ ComplexFloatMatrix[] full = Singular.fullSVD(A);
+ ComplexFloatMatrix fullS = ComplexFloatMatrix.diag(full[1], 3, 4);
+ assertShape(full[0], 3, 3);
+ assertShape(fullS, 3, 4);
+ assertShape(full[2], 4, 4);
+ assertReconstructs(A, full[0], fullS, full[2], 1e-4);
+ }
+
+ /** Asserts that {@code m} is a {@code rows} x {@code columns} matrix. */
+ private static void assertShape(ComplexDoubleMatrix m, int rows, int columns) {
+ assertEquals(rows, m.rows);
+ assertEquals(columns, m.columns);
+ }
+
+ /** Asserts that {@code m} is a {@code rows} x {@code columns} matrix. */
+ private static void assertShape(ComplexFloatMatrix m, int rows, int columns) {
+ assertEquals(rows, m.rows);
+ assertEquals(columns, m.columns);
+ }
+
+ /** Asserts that U * S * V^H equals A up to {@code tolerance} in the max norm. */
+ private static void assertReconstructs(ComplexDoubleMatrix A, ComplexDoubleMatrix U, ComplexDoubleMatrix S, ComplexDoubleMatrix V, double tolerance) {
+ assertEquals(0.0, U.mmul(S).mmul(V.hermitian()).sub(A).normmax(), tolerance);
+ }
+
+ /** Asserts that U * S * V^H equals A up to {@code tolerance} in the max norm. */
+ private static void assertReconstructs(ComplexFloatMatrix A, ComplexFloatMatrix U, ComplexFloatMatrix S, ComplexFloatMatrix V, double tolerance) {
+ assertEquals(0.0, U.mmul(S).mmul(V.hermitian()).sub(A).normmax(), tolerance);
+ }
+}
diff --git a/src/test/java/org/jblas/TestSolve.java b/src/test/java/org/jblas/TestSolve.java
new file mode 100644
index 0000000..9d7b458
--- /dev/null
+++ b/src/test/java/org/jblas/TestSolve.java
@@ -0,0 +1,114 @@
+// --- BEGIN LICENSE BLOCK ---
+/*
+ * Copyright (c) 2009, Mikio L. Braun
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials provided
+ * with the distribution.
+ *
+ * * Neither the name of the Technische Universität Berlin nor the
+ * names of its contributors may be used to endorse or promote
+ * products derived from this software without specific prior
+ * written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+// --- END LICENSE BLOCK ---
+
+package org.jblas;
+
+import org.jblas.DoubleMatrix;
+import org.jblas.FloatMatrix;
+import org.jblas.Solve;
+import org.jblas.util.Random;
+import org.junit.*;
+import static org.junit.Assert.*;
+
+/**
+ * Tests for methods in Solve
+ *
+ * Created: 1/18/13, 12:10 PM
+ *
+ * @author Mikio L. Braun
+ */
+public class TestSolve {
+ @Test
+ public void testLeastSquaresDouble() {
+ DoubleMatrix A = new DoubleMatrix(3, 3, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0);
+ DoubleMatrix B = new DoubleMatrix(3, 1, 3.0, 2.0, -1.0);
+ // A is singular (column 0 - 2 * column 1 + column 2 = 0); the expected value is the minimum-norm solution.
+ assertEquals(0.0, new DoubleMatrix(3, 1, -23.0 / 9, -2.0 / 3, 11.0 / 9).sub(Solve.solveLeastSquares(A, B)).normmax(), 1e-10);
+ }
+
+ @Test
+ public void testLeastSquaresFloat() {
+ FloatMatrix A = new FloatMatrix(3, 3, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f);
+ FloatMatrix B = new FloatMatrix(3, 1, 3.0f, 2.0f, -1.0f);
+ assertEquals(0.0f, new FloatMatrix(3, 1, -23.0f / 9, -2.0f / 3, 11.0f / 9).sub(Solve.solveLeastSquares(A, B)).normmax(), 1e-5f);
+ }
+
+ @Test
+ public void testLeastSquaresWideMatrixDouble() {
+ DoubleMatrix A = new DoubleMatrix(2, 3, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
+ DoubleMatrix B = new DoubleMatrix(2, 1, 1.0, -1.0);
+ // A has full row rank, so the underdetermined system is solved exactly.
+ assertEquals(0.0, B.sub(A.mmul(Solve.solveLeastSquares(A, B))).normmax(), 1e-10);
+ }
+
+ @Test
+ public void testLeastSquaresWideMatrixFloat() {
+ FloatMatrix A = new FloatMatrix(2, 3, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f);
+ FloatMatrix B = new FloatMatrix(2, 1, 1.0f, -1.0f);
+ // Single-precision counterpart of testLeastSquaresWideMatrixDouble.
+ assertEquals(0.0f, B.sub(A.mmul(Solve.solveLeastSquares(A, B))).normmax(), 1e-5f);
+ }
+
+ @Test
+ public void testPinvDouble() {
+ DoubleMatrix A = new DoubleMatrix(3, 2, 1.0, 3.0, 5.0, 2.0, 4.0, 6.0);
+ // Check the tall matrix as well as its wide transpose.
+ assertPseudoInverse(A);
+ assertPseudoInverse(A.transpose());
+ }
+
+ @Test
+ public void testPinvFloat() {
+ FloatMatrix A = new FloatMatrix(3, 2, 1.0f, 3.0f, 5.0f, 2.0f, 4.0f, 6.0f);
+ // Check the tall matrix as well as its wide transpose.
+ assertPseudoInverse(A);
+ assertPseudoInverse(A.transpose());
+ }
+
+ /** Asserts M * pinv(M) * M = M and pinv(M) * M * pinv(M) = pinv(M) up to 1e-10. */
+ private static void assertPseudoInverse(DoubleMatrix M) {
+ DoubleMatrix pinvM = Solve.pinv(M);
+ assertEquals(0.0, M.mmul(pinvM).mmul(M).sub(M).normmax(), 1e-10);
+ assertEquals(0.0, pinvM.mmul(M).mmul(pinvM).sub(pinvM).normmax(), 1e-10);
+ }
+
+ /** Asserts M * pinv(M) * M = M and pinv(M) * M * pinv(M) = pinv(M) up to 1e-5. */
+ private static void assertPseudoInverse(FloatMatrix M) {
+ FloatMatrix pinvM = Solve.pinv(M);
+ assertEquals(0.0f, M.mmul(pinvM).mmul(M).sub(M).normmax(), 1e-5f);
+ assertEquals(0.0f, pinvM.mmul(M).mmul(pinvM).sub(pinvM).normmax(), 1e-5f);
+ }
+}
diff --git a/src/test/java/org/jblas/ranges/RangeTest.java b/src/test/java/org/jblas/ranges/RangeTest.java
new file mode 100644
index 0000000..bc53444
--- /dev/null
+++ b/src/test/java/org/jblas/ranges/RangeTest.java
@@ -0,0 +1,44 @@
+package org.jblas.ranges;
+
+import org.jblas.DoubleMatrix;
+import org.junit.Test;
+import static org.jblas.ranges.RangeUtils.*;
+import static org.jblas.JblasAssert.*;
+
+/**
+ * Testing the ranges facility.
+ *
+ * @author Mikio L. Braun
+ * File created June 12, 2012, 16:04
+ */
+
+public class RangeTest {
+
+ private DoubleMatrix A = new DoubleMatrix(3, 4, 1.0, 2.0, 3.0,
+ 4.0, 5.0, 6.0,
+ 7.0, 8.0, 9.0,
+ 10.0, 11.0, 12.0);
+
+ @Test
+ public void allRange() {
+ assertEquals(A, A.get(all(), all()));
+ assertEquals(new DoubleMatrix(3, 1, 1.0, 2.0, 3.0), A.get(all(), 0));
+ assertEquals(new DoubleMatrix(3, 1, 7.0, 8.0, 9.0), A.get(all(), 2));
+ }
+
+ @Test
+ public void pointRange() {
+ assertEquals(DoubleMatrix.scalar(8.0), A.get(point(1), point(2)));
+ }
+
+ @Test
+ public void indicesRange() {
+ assertEquals(new DoubleMatrix(1, 5, 1.0, 7.0, 4.0, 10.0, 1.0), A.get(0, indices(new int[] {0, 2, 1, 3, 0})));
+ }
+
+ @Test
+ public void mixedRanges() {
+ assertEquals(new DoubleMatrix(3, 2, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0), A.get(all(), interval(1, 3)));
+ assertEquals(new DoubleMatrix(2, 1, 11.0, 12.0), A.get(interval(1, 3), point(3)));
+ }
+}
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/pkg-java/jblas.git
More information about the pkg-java-commits
mailing list